RNN训练与BP算法

摘要:


本文主要讲述的RNN(循环神经网络)在训练过程中BP算法的推导。
在阅读本文之前希望读者看过我的另一篇文章BP算法心得体会。因为大部分的思路沿用的就是这篇文章的思路。
参考文章:
数学推导-1
数学推导-2

更新-2018-01-23:
之前写完这篇文章之后,回头看了一遍文章,发现在整个推导的过程都无视了时间维度的存在,所以后来查阅了相关的资料,发现目前网上有一部分RNN的推导过程和本文是一样的,比如上面给到的2篇参考文章,思路和本文是一致的。但是也存在另外一些版本的推导,其过程和本文的截然不同。
所以后来在参考了大神的代码后,重新思考了rnn的训练算法,因此重新写一个篇rnn和bptt供大家参考。

正文


RNN的一般原理介绍这里就不再重复了,本文关注的是RNN是如何利用BP算法来进行训练的。

推导


在推导BP算法之前,我们先做一些变量上的规定,这一步非常关键。
本文使用的RNN是只含一个隐藏层(多个隐藏层其实也是一样的道理)。其结构如下图所示:
这里写图片描述
(大家看到这个网络结构可能有些困惑,比如说,RNN是由多个网络组成的吗?这里值得注意的是,RNN就只由一个网络组成,图上有多个网络是在不同时刻的输入下的网络情况)
现在,作如下的一些规定:
vim 是输入层第 m 个输入与隐藏层中第 i 个神经元所连接的权重。
uin 是隐层自循环的权重(具体表现为上面结构图中那些紫色、绿色的线)
wkm 是隐藏层中第m个神经元与输出层第k个神经元连接的权重。
网络中共有 N(i) N(h) N(o)

netthi 表示隐藏层第 i 个神经元在 t 时刻激活前的输入。
具体为: netthi=N(i)m=1(vimxtm)+N(h)s=1(uisht1s)
经过激活后的输出为: hti=f(netthi)

nettyk 表示输出层第 k 个神经元在 t 时刻激活前的输入。
具体为: nettyk=N(h)m=1(wkmhtm)
经过激活后的输出为: otk=f(nettyk)


这里同样地,为了方便推导,假设损失函数 Et=0.5N(o)k=1(otkttk)2 (本文也会说明使用其他损失函数的情况)
E=stept=1Et
首先我们需要解决的问题就是求出:
Euin Ewkm Evim

1.先来求最简单的 Ewkm
和之前讲解BP的文章套路一样,我们可以对 Ewkm 使用链式法则,具体如下:
Ewkm=Enettyknettykwkm
对于等式右边第二项很好计算, nettykwkm=htm
和之前一样,我们定义等式右边第一项为误差信号 δtyk=Enettyk
δtyk=Enettyk=Eotkotknettyk 。(这一步的思路就是找到和 nettyk 有直接相关的变量)
故: δtyk=Enettyk=Eotkotknettyk=(otkttk)f(nettyk)
所以, Ewkm=δtykhtm=(otkttk)f(nettyk)htm

下面,我们推导 Evim
Evim=Enetthinetthivim
对于 netthivim=xtm
定义误差信号 δthi=Enetthi
δthi=Enetthi=Ehtihtinetthi=Ehtif(netthi)
整个RNN如果说最麻烦的推导,可能就是对于 Ehti 的推导。

按照以前的思路我们容易想到:
Ehti=N(o)k=1(Enettyknettykhti)=N(o)k=1(Enettykwki)

上面的推导对吗?如果推导到这里感觉没问题的话不妨思考一个问题,如果公式是这样,这里哪里体现了RNN具有“记忆”的功能?公式体现的只与当前时刻 t 有关。

我们注意到,和 hti 有直接函数关系的除了 nettyk 以外,其实还有一条等式,而恰恰是这条等式把每个时刻之间的关系串了起来。
就是: netthi=N(i)m=1(vimxtm)+N(h)s=1(uisht1s)
我把上式中的 t -> t+1 ,也就是往后推一个时刻,我们有:
nett+1hi=N(i)m=1(vimxt+1m)+N(h)s=1(uishts)
也就是说, hti 还和 nett+1hi 相关。所以上式应该改写成:
Ehti=N(o)k=1(Enettyknettykhti)+N(h)s=1(Enett+1hsnett+1hshti)=N(o)k=1(Enettykwki)+N(h)s=1(Enett+1hsusi)

其实就是多了一项。这个是大家需要注意的!
对于 Enettyk 其为输出层的误差信号,上面已经求过了,即 δtyk

Enett+1hs 其实就是 δt+1hs 。这个就是 t+1 时刻的隐藏层的一个误差信号。而 t 时刻隐藏层的误差信号与 t+1 时刻隐藏层的误差信号有关,或者换句话说法, t 时刻的隐藏层的误差信号积累了 t+1 时刻的误差,看到这里,其实我们就可以认识到一个问题,RNN确实具有一定的记忆能力。

Ok,把上式整理一下可以得到:
Ehti=N(o)k=1(δtykwki)+N(h)s=1(δt+1hsusi)
由于: δthi=Enetthi=Ehtihtinetthi=Ehtif(netthi) ,替换掉 Ehti 得到:
δthi=(N(o)k=1(δtykwki)+N(h)s=1(δt+1hsusi))f(netthi)
Evim=δthixtm
最后一个是 Euin 。其实其和 Evim 是一样的,(因为位于同一层)具体可以参考下面:

Euin=Enetthinetthiuin ,可以看到等式右边第一项就是上面推导过的隐藏层误差信号 δthi ,而第二项就是 ht1n
所以: Euin=Enetthinetthiuin=δthiht1n

小结


至此,RNN的bp算法算是推导完毕,我们如果看回整个推导过程,其实和前面文章介绍的BP没什么区别,最大的区别在于RNN具有时序性,所以在隐藏层的误差信号处理时需要格外的注意,下面,我们可以从结构图来看待这一个问题,这种角度也可以加深我们对所谓“反向传播”有多一个深刻的理解。

这里写图片描述

这个是上面的结构图,我关注一下“紫色”的线,紫色线连接的是t时刻隐藏层和t+1时刻的隐藏层。我们从误差传播的角度来看。对于t时刻隐藏层某一个神经元而言,其误差可以分为两部分来源,第一部分就是t时刻本身的(黑色线,连接隐藏层和输出层这些),另外一部分就是t+1时刻时候隐藏层和隐藏层(自循环层)。
而这两部分恰恰对应了上面公式的两个部分。
公式中红色部分就是t时刻的误差,蓝色部分就是来自于t+1时刻。
δthi=(N(o)k=1(δtykwki)+N(h)s=1(δt+1hsusi))f(netthi)

下面再补一副很简陋的图,想表达的意思和上面一样。
这里写图片描述

关于训练过程的细节-1


这里可能有人会疑问,计算t时刻误差需要用到t+1时刻的误差,这个不是有背常理吗?这里需要注意的,神经网络里面是先前向计算,然后反向传播误差。所以每次训练,先从t=0时刻前向计算至最后一个时刻t。然后从t时刻反向传播误差。所以这里需要保存每一个时刻隐藏层、输出层的输出。

关于训练过程的细节-2


最后一个时刻由于没有下一个时刻传回来的隐藏层误差,所以下式中蓝色一项为0。
δthi=(N(o)k=1(δtykwki)+N(h)s=1(δt+1hsusi))f(netthi)
即:
δthi=N(o)k=1(δtykwki)

关于损失函数


和之前BP算法推导一样,其实损失函数就只有在这一步中产生影响。
δtyk=Enettyk=Eotkotknettyk
完全可以保留 Eotk 这个记号,并不影响后面的计算。

至此,两篇关于BP算法的文章算是告一段落,希望大家能够从中学习到东西。

猜你喜欢

转载自blog.csdn.net/qq_22238533/article/details/79079898
今日推荐