【TensorFlow】理解tf.nn.dynamic_rnn方法(附详细代码)

本文是在参考资料1的基础上加入更多细节完成,并非完全原创,感谢原创同学,尊重支持原创才能让社区更加健康。

这次在模型优化的时候加入了一个RNN结构,TensorFlow里有封装好的RNN函数,我们可以直接调用,RNN详细介绍见参考资料2

TensorFlow官网给的标准API:
注意: 这个是TF1.0版本下的,在2.0以上版本,dynamic_rnn是在 tf.compat.v1.nn.dynamic_rnn

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

参数说明:

  • cell: LSTM、GRU等的记忆单元。cell参数代表一个LSTM或GRU的记忆单元,也就是一个cell,是RNN中最小的单元结构。例如,cell = tf.nn.rnn_cell.LSTMCell((num_units),其中,num_units表示rnn cell中神经元个数,也就是下文的cell.output_size。返回一个LSTM或GRU cell,作为参数传入。多个cell组成了一个完整的RNN结构。

  • inputs: 输入的训练或测试数据,一般格式为[batch_size, max_time, embed_size],其中batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,embed_size表示嵌入的词向量的维度。

  • sequence_length: 是一个list,假设你输入了三句话,且三句话的长度分别是5,10,25,那么sequence_length=[5,10,25]。

  • time_major: 决定了输出tensor的格式,如果为True, 张量的形状必须为 [max_time, batch_size,cell.output_size]。如果为False, tensor的形状必须为[batch_size, max_time, cell.output_size],cell.output_size表示rnn cell中神经元个数。

  • 返回值:元组(outputs, states)

  • outputs: outputs很容易理解,就是每个cell会有一个输出

  • states: states表示最终的状态,也就是序列中最后一个cell输出的状态。一般情况下states的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state 的形状为 [2,batch_size, cell.output_size ],其中2也对应着 LSTM 中的 cell state 和 hidden state。

可以看到,tf.nn.dynamic_rnn 这个方法有两个返回值,outputsstates,那这俩是什么关系呢?另外 state 的形状为什么会发生改变?

先回答第二个问题,为什么 state 的形状会发生改变?

首先看当 cell 是LSTM类型时,states形状为 [2,batch_size,cell.output_size ];当cell为GRU时,states形状为[batch_size, cell.output_size ]。其原因是因为 LSTM 和 GRU 的结构本身不同,如下面两个图所示,这是 LSTM 的 cell 结构,每个 cell 会有两个输出: C t C_{t} h t h_t ,上面这个图是输出 C t C_t ,代表哪些信息应该被记住哪些应该被遗忘; 下面这个图是输出 h t h_t ,代表这个cell的最终输出,LSTM的 states 是由 C t C_t h t h_t 组成的,即 states = (c, h)。
更新细胞状态

输出信息
当 cell 为 GRU 时,state 就只有一个了,原因是GRU将 C t C_t h t h_t 进行了简化,将其合并成了 h t h_t ,如下图所示,GRU将遗忘门和输入门合并成了更新门,另外 cell 不再有细胞状态 cell state,只有hidden state

在这里插入图片描述
再回答第一个问题,outputs 和 states,这俩是什么关系呢?

对于不同的 cell 类型,outputsstates 的关系是有差异的。
如果 cell 为 LSTM,那 states 是个 tuple,分别代表 C t C_t h t h_t ,其中 h t h_t 与outputs中对应的最后一个时刻(即最后一个cell)的输出相等,这里再细说一下,outputs 输出的是每个 cell 的 h,也就是说整个RNN结构里有多少个 cell,outputs 就有多少个 h值,而 states 输出的是最后一个cell 的 C 和 h,它是h 和 outputs 里最后一个h 值是一样的;如果cell为GRU,那么同理,states其实就是 h t h_t

Talk is cheap , show me your code

import tensorflow as tf
import numpy as np
 
def dynamic_rnn(rnn_type='lstm'):
    # 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time), 控制序列长度, 4代表每个序列的维度
    X = np.random.randn(3, 6, 4)
 
    # 第二个输入的实际长度为4
    X[1, 4:] = 0
 
    #记录三个输入的实际步长
    X_lengths = [6, 4, 6]
 
    rnn_hidden_size = 5
    if rnn_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
 
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
 
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        o1, s1 = session.run([outputs, last_states])
        print(np.shape(o1))
        print("*"*20)
        print(o1)
        print("*"*20)
        print(np.shape(s1))
        print("*"*20)
        print(s1)
 
if __name__ == '__main__':
    dynamic_rnn(rnn_type='lstm')

cell类型为LSTM,输入的形状为 [ 3, 6, 4 ],经过 tf.nn.dynamic_rnnoutputs 的形状为 [ 3, 6, 5 ],states 形状为 [ 2, 3, 5 ],其中 state 第一部分为 c,代表 cell state,第二部分为 h,代表 hidden state,这就是形状里的第一维2的构成,3是 batch_size,因为我们一次性输入的是3条序列,5是每个输出向量的维度。可以看到 hidden state 与 对应的 outputs 的最后一行是相等的。另外需要注意的是输入一共有三个序列,但第二个序列的长度只有4,可以看到 outputs 中对应的两行值都为0,所以 hidden state 对应的是最后一个不为0的部分。tf.nn.dynamic_rnn 通过设置 sequence_length 来实现这一逻辑。

输出结果1:

(3, 6, 5)
********************
[[[ 0.0146346  -0.04717453 -0.06930042 -0.06065602  0.02456717]
  [-0.05580321  0.08770171 -0.04574306 -0.01652854 -0.04319528]
  [ 0.09087799  0.03535907 -0.06974291 -0.03757408 -0.15553619]
  [ 0.10003044  0.10654698  0.21004055  0.13792148 -0.05587583]
  [ 0.13547596 -0.014292   -0.0211154  -0.10857875  0.04461256]
  [ 0.00417564 -0.01985144  0.00050634 -0.13238986  0.14323784]]
 
 [[ 0.04893576  0.14289175  0.17957205  0.09093887 -0.0507192 ]
  [ 0.17696126  0.09929577  0.21185635  0.20386451  0.11664373]
  [ 0.15658667  0.03952745 -0.03425637  0.00773833 -0.03546742]
  [-0.14002582 -0.18578786 -0.08373584 -0.25964601  0.04090167]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]]
 
 [[ 0.18564152  0.01531695  0.13752453  0.17188506  0.19555427]
  [ 0.13703949  0.14272294  0.21313036  0.07417354  0.0477547 ]
  [ 0.23021792  0.04455495  0.10204565  0.17159792  0.34148467]
  [ 0.0386402   0.0387848   0.02134559  0.00110381  0.08414687]
  [ 0.01386241 -0.02629686 -0.0733538  -0.03194245  0.13606553]
  [ 0.01859433 -0.00585316 -0.04007138  0.03811594  0.21708331]]]
********************
(2, 3, 5)
********************
LSTMStateTuple(
       c=array([[ 0.00909146, -0.03747076,  0.0008946 , -0.23459786,  0.29565899],
                [-0.18409266, -0.30463044, -0.28033809, -0.49032542,  0.12597639],
                [ 0.04494702, -0.01359631, -0.06706629,  0.06766361,  0.40794032]]), 
       h=array([[ 0.00417564, -0.01985144,  0.00050634, -0.13238986,  0.14323784],
                [-0.14002582, -0.18578786, -0.08373584, -0.25964601,  0.04090167],
                [ 0.01859433, -0.00585316, -0.04007138,  0.03811594,  0.21708331]])
               )

cel l类型为 GRU,我们看看到,输入的形状为 [ 3, 6, 4 ],经过 tf.nn.dynamic_rnnoutputs 的形状为 [ 3, 6, 5 ],state形状为 [ 3, 5 ]。可以看到 state 与 对应的 outputs 的最后一行是相等的

输出结果2:

(3, 6, 5)
********************
[[[-0.05190962 -0.13519617  0.02045928 -0.0821183   0.28337528]
  [ 0.0201574   0.03779418 -0.05092804  0.02958051  0.12232347]
  [ 0.14884441 -0.26075898  0.1821795  -0.03454954  0.18424161]
  [-0.13854156 -0.26565378  0.09567164 -0.03960079  0.14000589]
  [-0.2605973  -0.39901657  0.12495693 -0.19295695  0.52423598]
  [-0.21596414 -0.63051687  0.20837501 -0.31775378  0.77519457]]
 
 [[-0.1979659  -0.30253523  0.0248779  -0.17981144  0.41815343]
  [ 0.34481129 -0.05256187  0.1643036   0.00739746  0.27384158]
  [ 0.49703664  0.22241165  0.27344766  0.00093435  0.09854949]
  [ 0.23312444  0.156997    0.25482553  0.0138156  -0.02302272]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]]
 
 [[-0.06401732  0.08605342 -0.03936866 -0.02287695  0.16947652]
  [-0.1775206  -0.2801672  -0.0387468  -0.20264583  0.58125297]
  [ 0.39408762 -0.44066425  0.25826641 -0.18851604  0.36172166]
  [ 0.0536013  -0.29902928  0.08891931 -0.03930039  0.0743423 ]
  [ 0.02304702 -0.0612499   0.09113458 -0.05169013  0.29876455]
  [-0.06711324  0.014125   -0.05856332 -0.05632359 -0.00390189]]]
********************
(3, 5)
********************
[[-0.21596414 -0.63051687  0.20837501 -0.31775378  0.77519457]
 [ 0.23312444  0.156997    0.25482553  0.0138156  -0.02302272]
 [-0.06711324  0.014125   -0.05856332 -0.05632359 -0.00390189]]

总结一下:
tf.nn.dynamic_rnn 这个函数可以控制RNN cell 个数,构建适合业务场景需求的RNN 结构。

参考资料:
1、https://zhuanlan.zhihu.com/p/43041436
2、https://www.jianshu.com/p/9dc9f41f0b29

发布了114 篇原创文章 · 获赞 55 · 访问量 8万+

猜你喜欢

转载自blog.csdn.net/zuolixiangfisher/article/details/103489979