传统RNN的内部结构图
ht ht-1 Xt+1代表隐藏层的输出
xt xt-1 xt+1 时间步的输入
输出计算公式:
ht=tanh([ Wt [xt,xh-1]+bt)
解释:输出ht是t时间步输入xt和t-1时间步的输出ht-1拼接,输入到全连接层,经过非线性激活函数tanh输出
- 缺点:
在解决长序列之间的关联时, RNN表现很差, 原因是在进行反向传播的时候, 过长的序列导致梯度的计算异常, 发生梯度消失或爆炸
Pytorch调用RNN使用
>>> import torch
>>> import torch.nn as nn
>>> rnn=nn.RNN(2,3,2)
>>> input=torch.randn(1,4,2)
>>> ht=torch.randn(2,4,3)
>>> output,hn=rnn(input,ht)
>>> output
tensor([[[-0.7655, -0.3165, 0.8719],
[ 0.3716, -0.0272, 0.7841],
[ 0.3092, 0.6021, 0.0438],
[-0.5681, -0.0064, 0.8295]]], grad_fn=<StackBackward>)
>>> hn
tensor([[[-0.9420, -0.9532, -0.6926],
[-0.1699, -0.3431, 0.3370],
[-0.5715, 0.0033, 0.3483],
[-0.9707, -0.6465, -0.6308]],
[[-0.7655, -0.3165, 0.8719],
[ 0.3716, -0.0272, 0.7841],
[ 0.3092, 0.6021, 0.0438],
[-0.5681, -0.0064, 0.8295]]], grad_fn=<StackBackward>)