个人写作笔记,如有问题,请不吝赐教!
目录如下:
上一篇中矩阵迹、矩阵范数的导数部分还留白白斩斩的,这里补充一下。
这里是上一篇的链接
矩阵求导
还差一个最普适的情况——“矩阵对矩阵求导”没有分析,这种情况在神经网络中貌似特别常见。这里对这种情况的公式进行证明,证明方法与证明列向量的方式是完全一样的,不同之处在于,列向量在证明第二项时可以提取公因式产生矩阵分块,而在矩阵对矩阵求导中,不会得到矩阵分块,只能是矩阵直积,具体过程如下。
设矩阵 A∈Cm×l,B∈Cl×n,W∈Cp×q,证明:
dWdAB=dWdA(B⊗Iq)+(A⊗Ip)dWdB
根据矩阵对矩阵求导的定义,有:
dWdAB=dXd∑a1ibi1dXd∑a2ibi1⋮dXd∑amibi1dXd∑a1ibi2dXd∑a2ibi2⋮dXd∑amibi2⋯⋯⋯dXd∑a1ibindXd∑a2ibin⋮dXd∑amibin=∑dXda1ibi1∑dXda2ibi1⋮∑dXdamibi1∑dXda1ibi2∑dXda2ibi2⋮∑dXdamibi2⋯⋯⋯∑dXda1ibin∑dXda2ibin⋮∑dXdamibin+∑a1idXdbi1∑a2idXdbi1⋮∑amidXdbi1∑a1idXdbi2∑a2idXdbi2⋮∑amidXdbi2⋯⋯⋯∑a1idXdbin∑a2idXdbin⋮∑amidXdbin
考虑前一项因子式,将其分解为以下形式:
∑dXda1ibi1∑dXda2ibi1⋮∑dXdamibi1∑dXda1ibi2∑dXda2ibi2⋮∑dXdamibi2⋯⋯⋯∑dXda1ibin∑dXda2ibin⋮∑dXdamibin=∑dXda1i∑dXda2i⋮∑dXdami∑dXda1i∑dXda2i⋮∑dXdami⋯⋯⋯∑dXda1i∑dXda2i⋮∑dXdami⋅b110⋮0b210⋮0⋮⋮bm10⋮00b11⋮00b21⋮0⋮⋮0bm1⋮0⋯⋯⋯⋯⋯⋯⋯⋯⋯00⋮b2100⋮b21⋮⋮00⋮bm1∣∣∣∣∣∣∣∣∣∣∣∣∣∣b120⋮0b220⋮0⋮⋮bm20⋮00b12⋮00b22⋮0⋮⋮0bm2⋮0⋯⋯⋯⋯⋯⋯⋯⋯⋯00⋮b2200⋮b22⋮⋮00⋮bm2∣∣∣∣∣∣∣∣∣∣∣∣∣∣⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯∣∣∣∣∣∣∣∣∣∣∣∣∣∣b1n0⋮0b2n0⋮0⋮⋮bmn0⋮00b1n⋮00b2n⋮0⋮⋮0bmn⋮0⋯⋯⋯⋯⋯⋯⋯⋯⋯00⋮b2n00⋮b2n⋮⋮00⋮bmn=∑dXda1i∑dXda2i⋮∑dXdami∑dXda1i∑dXda2i⋮∑dXdami⋯⋯⋯∑dXda1i∑dXda2i⋮∑dXdami⋅(B⊗Iq)=dWdA(B⊗Iq)
考虑后一项因子式,同样地有类似的分解方法:
∑dXdbi1∑dXdbi1⋮∑amidXdbi1∑dXdbi2∑dXdbi2⋮∑amidXdbi2⋯⋯⋯∑dXdbin∑dXdbin⋮∑amidXdbin=a110⋮0a210⋮0⋮⋮am10⋮00a11⋮00a21⋮0⋮⋮0am1⋮0⋯⋯⋯⋯⋯⋯⋯⋯⋯00⋮a2100⋮a21⋮⋮00⋮am1∣∣∣∣∣∣∣∣∣∣∣∣∣∣a120⋮0a220⋮0⋮⋮am20⋮00a12⋮00a22⋮0⋮⋮0am2⋮0⋯⋯⋯⋯⋯⋯⋯⋯⋯00⋮a2200⋮a22⋮⋮00⋮am2∣∣∣∣∣∣∣∣∣∣∣∣∣∣⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯⋯∣∣∣∣∣∣∣∣∣∣∣∣∣∣a1n0⋮0a2n0⋮0⋮⋮amn0⋮00a1n⋮00a2n⋮0⋮⋮0amn⋮0⋯⋯⋯⋯⋯⋯⋯⋯⋯00⋮a2n00⋮a2n⋮⋮00⋮amn⋅∑dXdbi1∑dXdbi1⋮∑dXdbi1∑dXdbi2∑dXdbi2⋮∑dXdbi2⋯⋯⋯∑dXdbin∑dXdbin⋮∑dXdbin=(A⊗Ip)dWdB
综上,原公式得到了证明。
矩阵的链式求导法则
不同于纯量的导数,矩阵链式求导法则会涉及到矩阵对矩阵的导数,因此在介绍完上一部分后引入矩阵的链式求导法则。设有函数矩阵:
G(F)=g11(F)g21(F)⋮gm1(F)g12(F)g22(F)⋮gm2(F)⋯⋯⋯g1n(F)g2n(F)⋮gmn(F)
其中矩阵F∈Cs×t又是关于A∈Cp×q的函数矩阵,可以证明,在由G,F合成的增广线性空间中存在这样的一个算子φ:Csp×qt↦Cm×n,使之对应的矩阵求导的链式法则,但是这个形式过于复杂,且仅具有理论上的作用,因此这里就不推导了。下面推导部分经常被使用的链式法则。
点乘运算的函数
点乘运算是特殊的函数矩阵,其导数在神经网络的反向传播中有重要作用。点乘运算的导数是一种特殊的链式法则。若矩阵A,C∈Cm×n,则定义点乘运算为:
A⊙B=a11b11a21b21⋮am1bm1a12b12a22b22⋮am2bm2⋯⋯⋯a1nb1na2nb2n⋮amnbmn
若有求导变量W∈Cm×n点乘运算的求导结果为:
dWdA⊙B=dWdA⊙[B⊗1(W)]+[A⊗1(W)]⊙dWdB
其中:1(W) 表示与W形状相同的全1矩阵。
按元素的函数运算
按元素的函数运算是指对矩阵内每一个元素均使用同一函数映射,从而改变整个矩阵的值,该运算的导数是特殊的链式运算法则。对于矩阵A∈Cm×n和单射f:x∈C↦C,定义按元素的函数运算为:
f⊙(A)≜fA≜f(a11)f(a21)⋮f(am1)f(a12)f(a22)⋮f(am2)⋯⋯⋯f(a1n)f(a2n)⋮f(amn)
若有求导变量W∈Cm×n,则按元素函数运算的求导结果为:
dWdf⊙(A)=[dxdfA⊗1(W)]⊙dWdA
其中:
dxdfA≜dxdfa11dxdfa21⋮dxdfam1dxdfa12dxdfa22⋮dxdfam2⋯⋯⋯dxdfa1ndxdfa2n⋮dxdfamn
上式表明,如果元素内存在函数的嵌套关系,同样可以如纯量函数链式法则一样展开,即:
dWdg∘f⊙(A)=[dxdgf⊙(A)⊗1(W)]⊙[dxdfA⊗1(W)]⊙dWdA
矩阵元素算子求导
将上式做适当推广即可得到矩阵元素算子的导数,定义:
F⊙(A)≜FA≜f11(a11)f21(a21)⋮fm1(am1)f12(a12)f22(a22)⋮fm2(am2)⋯⋯⋯f1n(a1n)f2n(a2n)⋮fmn(amn)
若有求导变量W∈Cm×n,则按元素函数运算的求导结果为:
dWdF⊙(A)=[dxdFA⊗1(W)]⊙dWdA
其中:
dxdFA≜dxdf11a11dxdf21a21⋮dxdfm1am1dxdf12a12dxdf22a22⋮dxdfm2am2⋯⋯⋯dxdf1na1ndxdf2na2n⋮dxdfmnamn
矩阵的迹
设矩阵 A,B∈Cn×n,W∈Cp×q
定义
对于方阵而言,定义矩阵的迹为主对角线上全体元素之和,即:
tr(A)=i=1∑naii
性质
讨论矩阵的迹时,只考虑方阵的情况,一般不是方阵不必讨论迹。矩阵的迹等于矩阵全体特征值之和,可由多项式韦达定理证得,即有:
tr(A)=i=1∑nλi
矩阵的迹满足乘积可换顺序,即:
tr(AB)=tr(BA)
微分性质
矩阵的迹就是一个纯量,因此求解时按一般纯量的求法即可。对于含有矩阵运算的迹,追迹计算与导数运算不能交换顺序:
dWd[tr(A)]=i=1∑ndWdaii
从上式也可看出,该式并不能得到反映矩阵A全体元素的表达式,因此不能交换导数和追迹的顺序。但是两矩阵乘积嵌套迹运算,可以简记为如下形式:
dWd[tr(AB)]=[tr(dwijdAB+AdwijdB)]n×n
矩阵的范数
∀A∈Cm×n,定义以下的范数:
矩阵原生范数
总和范数:∣∣A∣∣M=j=1∑ni=1∑m∣aij∣
F范数:∣∣A∣∣F=j=1∑ni=1∑m∣aij∣2
G范数:∣∣A∣∣G=n⋅i,jmaxi=1∑n∣aij∣
向量范数导出的矩阵范数
矩阵的最大奇异值称为矩阵的谱半径,用 ρ(A)=max{si}i=1n 表示。
行和范数:∣∣A∣∣∞=jmaxj=1∑n∣aij∣
列和范数:∣∣A∣∣1=jmaxi=1∑m∣aij∣
谱范数:∣∣A∣∣2=max{si}i=1n
矩阵范数的导数
矩阵范数大多不能直接求导,常见的能求导的范数有F范数,F范数的求法可由下式快捷算出:
∣∣A∣∣F2=tr(ATA)
简单起见,此处对范数进行平方,在常见的学习率正则化处理中,经常见到带有这种形式的误差项,对此平方求导,可得:
dWd∣∣A∣∣F2=dWd[tr(ATA)]=tr(dWdATA)=tr[dWdAT(A⊗Iq)+(AT⊗Ip)dWdA]
特别地,当 W∈Cm×m 即微分变量也是方阵时,根据追迹的乘积互换性,可以一步得到:
dWd∣∣A∣∣F2=tr[dWdAT(A⊗Iq)+(AT⊗Ip)dWdA]=2tr[dWdAT(A⊗Iq)]
向量范数的导数
将矩阵退化为列向量,就可以得到向量范数导数,这种情况更为常见。向量的二范数定义为:
∣∣X∣∣22=XTX
因此对于列向量 Y∈Cm×1 和 X∈Cn×1 ,有:
dXd∣∣Y∣∣22=dXdYT(Y⊗I1)+(YT⊗Im)dXdY=dXdYTY+V(YTdXTdY)=2⋅dXdYTY
常见矩阵导数表
设矩阵A,B,D,⋯∈C,列向量X,Y,Z∈Cq×1。
线性变换
dXdAXdXTdAXdXdXdAdAXdATdAm×qXdAdAB=V(A)=A=V(Iq)=V(Im)VT(XT)=X⊗Im=V(Im)VT(BT)
二次型
dXdXTAXdXTdXTAXdAdXTAXdAdXTAY=(A+AT)X=XT(A+AT)=(XT⊗Iq)dAdA(X⊗Iq)=XXT=XYT
迹运算型
利用迹运算的压缩特性和互换性,可以对某些范数对矩阵的导数进行化简,范数对向量的导数可以直接按求导乘积法则运算。如下所示:
dAd∣∣A∣∣F2dAd[tr(AB)]dAd∣∣AX∣∣22=[tr(daijdATA+ATdaijdA)]n×n=A+AT=dAd[tr(BA)]=BT=[tr(daijdXTATAX+XTATdaijdAX)]n×n=XT(A+AT)X
备注:以下公式可为上述各式化简提供简化手段:
tr(dwijdAB)[tr(dwijdAB)]n×n=⎩⎨⎧dwijdaijbji0i=ji=j=BT