传统RNN模型

传统RNN的内部结构图

在这里插入图片描述
ht ht-1 Xt+1代表隐藏层的输出
xt xt-1 xt+1 时间步的输入
输出计算公式:
ht=tanh([ Wt [xt,xh-1]+bt)
解释:输出ht是t时间步输入xt和t-1时间步的输出ht-1拼接,输入到全连接层,经过非线性激活函数tanh输出

  • 缺点:
    在解决长序列之间的关联时, RNN表现很差, 原因是在进行反向传播的时候, 过长的序列导致梯度的计算异常, 发生梯度消失或爆炸
Pytorch调用RNN使用
>>> import torch
>>> import torch.nn as nn
>>> rnn=nn.RNN(2,3,2)
>>> input=torch.randn(1,4,2)
>>> ht=torch.randn(2,4,3)
>>> output,hn=rnn(input,ht)
>>> output
tensor([[[-0.7655, -0.3165,  0.8719],
         [ 0.3716, -0.0272,  0.7841],
         [ 0.3092,  0.6021,  0.0438],
         [-0.5681, -0.0064,  0.8295]]], grad_fn=<StackBackward>)
>>> hn
tensor([[[-0.9420, -0.9532, -0.6926],
         [-0.1699, -0.3431,  0.3370],
         [-0.5715,  0.0033,  0.3483],
         [-0.9707, -0.6465, -0.6308]],

        [[-0.7655, -0.3165,  0.8719],
         [ 0.3716, -0.0272,  0.7841],
         [ 0.3092,  0.6021,  0.0438],
         [-0.5681, -0.0064,  0.8295]]], grad_fn=<StackBackward>)
发布了66 篇原创文章 · 获赞 1 · 访问量 7019

猜你喜欢

转载自blog.csdn.net/qq_41128383/article/details/105519328