CNN笔记(CS231N)——循环神经网络(Recurrent Neural Networks)

递归神经网络

上一讲讲了CNN的架构,那么当我们把时间这个维度考虑进来了以后,我们就得到了递归神经网络(RNN)。RNN的输入输出可以是一对多、多对一、多对多,分别对应不同的应用场景

RNN的核心部分是如下公式,旧状态+当前输入,经过一个函数,得到了新状态,新状态会被送到下一个时候参与运算。我们的这个函数fw在不同时间是固定的。

普通RNN

一般来说这个fw函数是tanh函数,W是我们需要学习的权重,分别与h和x相乘

为了方便起见,我们一般会将RNN展开,这样更加方便我们的理解。由于每一个时刻都会有一个输出,因此我们在计算损失函数的时候会将每个时刻的损失函数集中在一起,然后再求梯度

一个常见的例子就是机器翻译,我们把一种语言翻译成另外一种语言,这种情况下我们采用的结构使一个多对一网络+一个一对多网络。我们把这个多对一网络叫做编码器,一对多网络叫做解码器,中间生成的是一个特征向量,来表征这个句子的特征。

我们再举一个词汇预测的例子,我们需要一个网络来对单词进行预测,那么我们要做的就是在每个时刻输出下一个字母的概率大小

我们在测试输出的时候会将输出的值经过softmax,再依据概率值进行输出采样。我们在这里并不直接采用输出概率的最大值,而是采用采样的方法,原因是我们可以增加输出的多样性,来保证我们有多种结果供我们选择

当我们训练序列很长的时候我们需要很大的网络来计算损失函数,这种情况下我们在进行反向传播的时候需要的计算量非常大

因此我们采取与SGD类似的概念将网络划分成许多子网络分批计算损失函数,再进行反向传播来对最终结果进行近似

我们可以用这个网络做一些很有意思的事情,例如我们让网络去学习写C代码。一件非常有意思的事情是尽管网络不知道任何C语言规则,而仅仅是通过已有训练数据的前后连接关系,就能预测出一段非常近似于C语言的代码

如果我们深入去看网络内部的状态值,我们可以看出网络确实在学习C语言的某些特征。例如有的状态单元学习了if语句,有的学习了代码的深度等特征。以下红色部分表示某个单元的值在当前输入下值很大,蓝色表示很小

图像理解

我们还可以将CNN结合RNN来做图像理解

步骤就是我们删除一个预训练网络的最后的输出层,保留全连接层,将全连接层的特征向量作为第一个状态输入RNN,再用RNN来输出图像描述的句子

我们还可以对某幅图像的特定区域进行特殊关注,来增强整体的描述效果

主要思路就是我们在每个时刻不仅输出描述的词汇,还输出一个分布函数,将这个函数作为一个mask与CNN输出的特征图相乘,得到一个有权特征图,代表我们在当前时刻关注图像的某一特定部分

以下是一个测试的实例,我们可以看到当网络把注意点放在特定区域的时候,确实输出了相应的词汇

视觉问答

RNN另外的应用场景是视觉问答

主要的思路也是让网络关注图像的不同部分输出相应的词汇

LSTM

除了普通RNN,我们还可以采用多层RNN的网络结构

接下来让我们看看普通RNN存在的问题,我们在进行反向传播的时候一个RNN的传播路径如下

那么当我们一级级计算梯度的时候,根据以前得到的结论,梯度值会乘以很多个W,这样导致的问题就是如果W矩阵中最大的值大于1,那么我们最终就会发生梯度爆炸;如果最大的值小于1,,那么我们最终就会发生梯度消失。对于梯度爆炸,我们采用的方法就是梯度剪切,也就是当梯度值大于某个阈值的时候我们对其进行限制;而梯度消失,我们只能对我们的网络结构进行改进,这也就是LSTM的提出初衷

LSTM是长短期记忆网络,它将本来的一个状态单元输出分成了两个:叫做cell状态和hidden状态,cell状态不传递给下个时间,hidden状态直接输出并作为参数传递给下个时间,而hidden状态是在cell状态基础上计算得到的。在这个网络中还有一些新的变量:遗忘门、输入门、门门、输出门。门门与普通RNN网络的状态单元功能类似,计算当前时刻的状态输出,它与输入门相乘,输入门控制这部分值有多少进入cell状态中;遗忘门控制之前的cell状态有多少进入当前状态中,上一个状态值与当前状态值相加得到cell状态。最后输出门来控制cell状态有多少暴露给下一个状态。我们可以看到i、f、o都是经过sigmoid函数,而g是经过tanh函数。由于g相当于是普通RNN的状态单元,因此采用tanh函数很好理解,而对于其他三个我们将他们进行二值化理解,0代表当前状态不暴露给下一个状态,1代表当前状态暴露给下一个状态,由此来取得更大的自由度,最终获得更好的效果

在进行反向传播的时候,尽管求梯度的公式比较复杂,但是我们能看出从当前状态到上一个状态有一条直接的通路,这个过程中梯度只需要乘以一个f即可,而每个单元的f是不同的,也就减轻了梯度消失的问题

这样就相当于是从输出到输入有一条直通的路,这种想法与ResNet是非常类似的

以下就是本讲的总结

猜你喜欢

转载自blog.csdn.net/shanwenkang/article/details/87199122