Backpropagation Through Time

This article follows the corresponding chapter of 《动手学深度学习》 (Dive into Deep Learning) and presents its derivations in more detail.

I. Backpropagation Derivation for the RNN

1. Problem statement

The RNN's relations at time step $t$ are:
$$\left\{ \begin{array}{ll} h_t = W_{hx}x_t + W_{hh}h_{t-1} \\ O_t = W_{qh}h_t \end{array} \right.$$
Suppose the loss function is
$$L = \frac{1}{T}\sum_{t=1}^{T} l(O_t, y_t)$$
We want to compute
$$\frac{\partial L}{\partial W_{qh}}, \quad \frac{\partial L}{\partial W_{hx}}, \quad \frac{\partial L}{\partial W_{hh}}$$
Preliminaries: the chain rule for matrix derivatives and the basic rules and principles of differentiation are assumed throughout.
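The simplified RNN above (no activation function or bias, following the book's simplification) can be sketched in NumPy. All shapes and the random initialization below are illustrative assumptions:

```python
import numpy as np

# Hypothetical sizes: sequence length, input/hidden/output dimensions.
rng = np.random.default_rng(0)
T, n_x, n_h, n_q = 4, 3, 5, 2
W_hx = 0.1 * rng.normal(size=(n_h, n_x))
W_hh = 0.1 * rng.normal(size=(n_h, n_h))
W_qh = 0.1 * rng.normal(size=(n_q, n_h))

def forward(xs):
    """h_t = W_hx x_t + W_hh h_{t-1},  O_t = W_qh h_t,  with h_0 = 0."""
    h = np.zeros(n_h)
    hs, outs = [], []
    for x in xs:
        h = W_hx @ x + W_hh @ h   # hidden-state recursion
        hs.append(h)
        outs.append(W_qh @ h)     # per-step output
    return hs, outs

xs = [rng.normal(size=n_x) for _ in range(T)]
hs, outs = forward(xs)
```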
2. Solution

First, we compute $\frac{\partial L}{\partial W_{qh}}$.

For any time step $t$, we clearly have:
$$\frac{\partial L}{\partial O_t} = \frac{1}{T} \cdot \frac{\partial l(O_t, y_t)}{\partial O_t}, \qquad \mathrm{d}L = \mathrm{tr}\left( \left( \frac{\partial L}{\partial O_t} \right)^T \mathrm{d}O_t \right), \qquad O_t = W_{qh}h_t$$
Substituting $O_t$ into the differential, we get:
$$\mathrm{d}L = \mathrm{tr}\left( \sum_{t=1}^{T} \left( \frac{\partial L}{\partial O_t} \right)^T \mathrm{d}W_{qh}\, h_t \right)$$
Moving $h_t$ to the front inside the trace (by the cyclic property of the trace), we get:
$$\mathrm{d}L = \mathrm{tr}\left( \sum_{t=1}^{T} h_t \left( \frac{\partial L}{\partial O_t} \right)^T \mathrm{d}W_{qh} \right)$$
Therefore:
$$\frac{\partial L}{\partial W_{qh}} = \left( \sum_{t=1}^{T} h_t \left( \frac{\partial L}{\partial O_t} \right)^T \right)^T = \sum_{t=1}^{T} \frac{\partial L}{\partial O_t}\, h_t^T$$
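This result can be sanity-checked numerically. The sketch below assumes a concrete squared-error loss $l(O_t, y_t) = \frac{1}{2}\lVert O_t - y_t \rVert^2$ (so that $\frac{\partial L}{\partial O_t} = (O_t - y_t)/T$) and compares the analytic gradient with finite differences; all data is random test input:

```python
import numpy as np

rng = np.random.default_rng(1)
T, n_x, n_h, n_q = 4, 3, 5, 2
W_hx = 0.1 * rng.normal(size=(n_h, n_x))
W_hh = 0.1 * rng.normal(size=(n_h, n_h))
W_qh = 0.1 * rng.normal(size=(n_q, n_h))
xs = [rng.normal(size=n_x) for _ in range(T)]
ys = [rng.normal(size=n_q) for _ in range(T)]

def loss(Wqh):
    """L = (1/T) sum_t 1/2 ||Wqh h_t - y_t||^2; also return the h_t values."""
    h, L, hs = np.zeros(n_h), 0.0, []
    for x, y in zip(xs, ys):
        h = W_hx @ x + W_hh @ h
        hs.append(h)
        L += 0.5 * np.sum((Wqh @ h - y) ** 2)
    return L / T, hs

L0, hs = loss(W_qh)

# Analytic gradient from the derivation: sum over t of (dL/dO_t) h_t^T.
grad = sum(np.outer((W_qh @ h - y) / T, h) for h, y in zip(hs, ys))

# Forward finite differences, entry by entry.
eps = 1e-6
num = np.zeros_like(W_qh)
for i in range(n_q):
    for j in range(n_h):
        Wp = W_qh.copy()
        Wp[i, j] += eps
        num[i, j] = (loss(Wp)[0] - L0) / eps
```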
Next we try to compute $\frac{\partial L}{\partial W_{hx}}$ and $\frac{\partial L}{\partial W_{hh}}$.

We start from time step $T$ and work backwards (here $\mathrm{prod}(\cdot)$ denotes the chain-rule product for matrix derivatives).

Recall the forward equations:
$$\left\{ \begin{array}{ll} h_t = W_{hx}x_t + W_{hh}h_{t-1} \\ O_t = W_{qh}h_t \end{array} \right.$$
$$\frac{\partial L}{\partial h_T} = \mathrm{prod}\left( \frac{\partial L}{\partial O_T}, \frac{\partial O_T}{\partial h_T} \right)$$
For time step $T-1$, we have
$$\frac{\partial L}{\partial h_{T-1}} = \mathrm{prod}\left( \frac{\partial L}{\partial O_{T-1}}, \frac{\partial O_{T-1}}{\partial h_{T-1}} \right) + \mathrm{prod}\left( \frac{\partial L}{\partial h_T}, \frac{\partial h_T}{\partial h_{T-1}} \right)$$
…
Similarly, for any time step $t < T$:
$$\frac{\partial L}{\partial h_t} = \mathrm{prod}\left( \frac{\partial L}{\partial O_t}, \frac{\partial O_t}{\partial h_t} \right) + \mathrm{prod}\left( \frac{\partial L}{\partial h_{t+1}}, \frac{\partial h_{t+1}}{\partial h_t} \right)$$
Evaluating the partial derivatives with the same trace-based matrix chain rule used above for $\frac{\partial L}{\partial W_{qh}}$ gives:
$$\frac{\partial L}{\partial h_t} = W_{qh}^T \frac{\partial L}{\partial O_t} + W_{hh}^T \frac{\partial L}{\partial h_{t+1}}$$
Unrolling this recursion yields:
$$\frac{\partial L}{\partial h_t} = \sum_{i=t}^{T} \left( W_{hh}^T \right)^{T-i} W_{qh}^T \frac{\partial L}{\partial O_{T+t-i}}$$
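Both forms of $\frac{\partial L}{\partial h_t}$ — the backward recursion and its unrolled closed form — can be compared numerically. The matrices and the per-step gradients $\frac{\partial L}{\partial O_t}$ below are arbitrary test data:

```python
import numpy as np

rng = np.random.default_rng(2)
T, n_h, n_q = 5, 4, 3
W_hh = 0.3 * rng.normal(size=(n_h, n_h))
W_qh = rng.normal(size=(n_q, n_h))
dO = [rng.normal(size=n_q) for _ in range(T)]   # dO[t-1] stands for dL/dO_t

# Backward recursion: dL/dh_t = W_qh^T dL/dO_t + W_hh^T dL/dh_{t+1}.
dh_rec = [None] * T
dh_rec[T - 1] = W_qh.T @ dO[T - 1]
for t in range(T - 2, -1, -1):
    dh_rec[t] = W_qh.T @ dO[t] + W_hh.T @ dh_rec[t + 1]

# Closed form: sum_{i=t}^{T} (W_hh^T)^{T-i} W_qh^T dL/dO_{T+t-i}  (t is 1-indexed).
def closed_form(t):
    return sum(np.linalg.matrix_power(W_hh.T, T - i) @ W_qh.T @ dO[T + t - i - 1]
               for i in range(t, T + 1))
```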
Therefore
$$\frac{\partial L}{\partial W_{hx}} = \mathrm{prod}\left( \frac{\partial L}{\partial h_t}, \frac{\partial h_t}{\partial W_{hx}} \right), \qquad \frac{\partial L}{\partial W_{hh}} = \mathrm{prod}\left( \frac{\partial L}{\partial h_t}, \frac{\partial h_t}{\partial W_{hh}} \right)$$
and carrying out the products (the same trace-based chain rule as above; the details are left to the reader):
$$\frac{\partial L}{\partial W_{hx}} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t}\, x_t^T, \qquad \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t}\, h_{t-1}^T$$
Together with the result obtained earlier:
$$\frac{\partial L}{\partial W_{qh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial O_t}\, h_t^T$$
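As a final check, the BPTT gradients for $W_{hx}$ and $W_{hh}$ derived above can be compared against finite differences, again assuming the squared-error loss $l(O_t, y_t) = \frac{1}{2}\lVert O_t - y_t \rVert^2$ for concreteness:

```python
import numpy as np

rng = np.random.default_rng(3)
T, n_x, n_h, n_q = 4, 3, 5, 2
W_hx = 0.1 * rng.normal(size=(n_h, n_x))
W_hh = 0.1 * rng.normal(size=(n_h, n_h))
W_qh = 0.1 * rng.normal(size=(n_q, n_h))
xs = [rng.normal(size=n_x) for _ in range(T)]
ys = [rng.normal(size=n_q) for _ in range(T)]

def loss(Whx, Whh):
    """L = (1/T) sum_t 1/2 ||W_qh h_t - y_t||^2; hs[t] is h_t, hs[0] = h_0 = 0."""
    h, L, hs = np.zeros(n_h), 0.0, [np.zeros(n_h)]
    for x, y in zip(xs, ys):
        h = Whx @ x + Whh @ h
        hs.append(h)
        L += 0.5 * np.sum((W_qh @ h - y) ** 2)
    return L / T, hs

L0, hs = loss(W_hx, W_hh)

# BPTT: run the dL/dh_t recursion backwards, accumulating the weight gradients.
dh = np.zeros(n_h)
g_hx, g_hh = np.zeros_like(W_hx), np.zeros_like(W_hh)
for t in range(T, 0, -1):
    dO = (W_qh @ hs[t] - ys[t - 1]) / T        # dL/dO_t
    dh = W_qh.T @ dO + W_hh.T @ dh             # dL/dh_t (dh was dL/dh_{t+1})
    g_hx += np.outer(dh, xs[t - 1])            # dL/dW_hx += dL/dh_t x_t^T
    g_hh += np.outer(dh, hs[t - 1])            # dL/dW_hh += dL/dh_t h_{t-1}^T

# Forward finite differences for comparison.
eps = 1e-6
num_hx, num_hh = np.zeros_like(W_hx), np.zeros_like(W_hh)
for i in range(n_h):
    for j in range(n_x):
        Wp = W_hx.copy(); Wp[i, j] += eps
        num_hx[i, j] = (loss(Wp, W_hh)[0] - L0) / eps
    for j in range(n_h):
        Wp = W_hh.copy(); Wp[i, j] += eps
        num_hh[i, j] = (loss(W_hx, Wp)[0] - L0) / eps
```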
This completes the backpropagation derivation for the RNN.
II. Backpropagation Derivation for the LSTM

1. Problem statement
$$\begin{array}{ll} I_t = \sigma\left( W_{xi}X_t + W_{hi}H_{t-1} + b_i \right) \\ F_t = \sigma\left( W_{xf}X_t + W_{hf}H_{t-1} + b_f \right) \\ O_t = \sigma\left( W_{xo}X_t + W_{ho}H_{t-1} + b_o \right) \\ C_t' = \tanh\left( W_{xc}X_t + W_{hc}H_{t-1} + b_c \right) \\ C_t = F_t \odot C_{t-1} + I_t \odot C_t' \\ H_t = O_t \odot \tanh(C_t) \\ Y_t = W_{qh}H_t + b_q \end{array}$$
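One step of these LSTM equations can be sketched in NumPy; the layer sizes and random initialization below are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(4)
n_x, n_h, n_q = 3, 5, 2

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One (W_x*, W_h*, b_*) triple per gate: input, forget, output, candidate.
params = {g: (0.1 * rng.normal(size=(n_h, n_x)),
              0.1 * rng.normal(size=(n_h, n_h)),
              np.zeros(n_h)) for g in "ifoc"}
W_qh, b_q = 0.1 * rng.normal(size=(n_q, n_h)), np.zeros(n_q)

def lstm_step(x, h_prev, c_prev):
    def affine(g):
        Wx, Wh, b = params[g]
        return Wx @ x + Wh @ h_prev + b
    i = sigmoid(affine("i"))          # input gate  I_t
    f = sigmoid(affine("f"))          # forget gate F_t
    o = sigmoid(affine("o"))          # output gate O_t
    c_tilde = np.tanh(affine("c"))    # candidate   C'_t
    c = f * c_prev + i * c_tilde      # memory cell C_t
    h = o * np.tanh(c)                # hidden state H_t
    y = W_qh @ h + b_q                # output      Y_t
    return h, c, y

h, c, y = lstm_step(rng.normal(size=n_x), np.zeros(n_h), np.zeros(n_h))
```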