深度学习-*-RNN正向及反向传播

版权声明:原创文章未经博主允许不得转载O(-_-)O!!! https://blog.csdn.net/u013894072/article/details/84502501

RNN简介

RNN(循环神经网络)是深度神经网络中,应用最广泛的两种神经网络架构之一。并且,作为一种时序结构的神经网络,RNN经常用于时序相关的问题中,且在NLP中应用广泛。还有一种RNN称为递归神经网络,虽然名字类似,但是却是不一样的架构。

RNN图示

RNN结构图
x t x_t 是输入层数据, s t s_t 是隐含层数据, o t o_t 是输出层数据,我们令:每一个 y t y_t 是t时刻对应的真实输出, y t h a t y^{hat}_t 是对 o t o_t 进行softmax计算之后得到的估计值。 U U 是输入层到隐含层的权重, W W 是上一时刻隐含层到当前时刻隐含层的权重, V V 是隐含层到输出层的权重。

正向传播

由上图易知: a t = b + W s t 1 + U x t a_t=b+W*s_{t-1}+U*x_t s t = t a n h ( a t ) s_t=tanh(a_t) o t = c + U s t o_t=c+U*s_t y t h a t = s o f t m a x ( o t ) y^{hat}_t=softmax(o_t)
我们假设t时候的损失函数为 L t L^t (一般为交叉熵损失/负对数似然),则一次正向传播的损失 L = t L t L=\sum_tL^t

反向传播

反向传播中,还是使用链式推导方法,与传统的神经网络推导类似。但不一样的地方在于隐含层受到了前一时刻隐含层的影响,故 t t 时刻隐含层 s t s_t 的误差传播源来自于 o t o_t s t + 1 s_{t+1} 两个方向。这里推导我是参考了很多博客文章,但是一直都没理解。后来看了文献1,多少有点明白的意思。有幸各位大牛们看了这篇文章,请指点。
我们首先看误差对 o t o_t 的影响 o t L = L o t = L t o t = y t y t h a t I i = j y t \nabla o_tL=\frac{\partial L}{\partial o_t}=\frac{\partial L^t}{\partial o_t}=y_t*y^{hat}_t-I_{i=j}*y_t 其中i是当前数据所属真实类别索引,j为所有类别的索引分量。当i=j时, I i = j I_{i=j} 是1,否则是0,参考了文献2。
假设总时刻长度为 t = τ t=\tau , s t L = V T o t L t = τ \nabla s_tL = V^T*\nabla o_tL,t=\tau s t L = ( s t + 1 L s t L ) s t + 1 L + ( o t L s t L ) o t L t < τ \nabla s_tL=(\frac{\partial s_{t+1}L}{\partial s_tL})*\nabla s_{t+1}L + (\frac{\partial o_{t}L}{\partial s_tL})*\nabla o_{t}L,t<\tau
也就是说最后一个节点的隐含层误差只来源于他的输出层。其余各层除了本身输出层外,还会有上一层的误差来源。通过链式求导有
s t L = W T s t + 1 L d i a g ( 1 s t + 1 2 ) + V T o t L t < τ d i a g 线 \nabla s_tL=W^T*s_{t+1}L*diag(1-s_{t+1}^2)+V^T*\nabla o_tL,t<\tau,diag是对角线矩阵
故各种变量的梯度值为所有时刻梯度值的和:
c L = t o t L \nabla _cL=\sum_t \nabla o_tL b L = t d i a g ( 1 s t 2 ) o t L \nabla _bL=\sum_t diag(1-s_t^2)\nabla o_tL V L = t o t L s t T \nabla _VL=\sum_t \nabla o_tL *s_t^T W L = t d i a g ( 1 s t 2 ) s t L s t 1 T \nabla _WL=\sum_t diag(1-s_t^2)*\nabla s_tL*s_{t-1}^T U L = t d i a g ( 1 s t 2 ) s t L x t T \nabla _UL=\sum_t diag(1-s_t^2)*\nabla s_tL*x_{t}^T

参考文献

1.深度学习(AI圣经) P327
2.softmax函数及其导数
3.RNN求解过程推导与实现

猜你喜欢

转载自blog.csdn.net/u013894072/article/details/84502501
今日推荐