tensorflow.nn.bidirectional_dynamic_rnn()函数的用法

版权声明:凡由本人原创,如有转载请注明出处https://me.csdn.net/qq_41424519,谢谢合作 https://blog.csdn.net/qq_41424519/article/details/82112904

这里写图片描述

开门见山来两张比较蛋疼的图,它们确实很流行。直奔主题。

def bidirectional_dynamic_rnn(
cell_fw, # 前向RNN
cell_bw, # 后向RNN
inputs, # 输入
sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
initial_state_fw=None,  # 前向的初始化状态(可选)
initial_state_bw=None,  # 后向的初始化状态(可选)
dtype=None, # 初始化和输出的数据类型(可选)
parallel_iterations=None,
swap_memory=False, 
time_major=False,
# 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`. 
# 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`. 
scope=None
)

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的二元组。假设 time_major=false, 而且tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。 
output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的二元组。 
output_state_fwoutput_state_bw的类型为LSTMStateTuple。 
LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

LSTM应用到双向RNN中

而cell_fw和cell_bw的定义是完全一样的。如果这两个cell选LSTM cell整个结构就是双向LSTM了。

# lstm模型正方向传播的RNN
lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)
# 反方向传播的RNN
lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)

但是看来看去,输入两个cell都是相同的啊? 
其实在bidirectional_dynamic_rnn函数的内部,会把反向传播的cell使用array_ops.reverse_sequence的函数将输入的序列逆序排列,使其可以达到反向传播的效果。 
在实现的时候,我们是需要传入两个cell作为参数就可以了:

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, embedded_chars,  dtype=tf.float32)

embedded_chars为输入的tensor,[batch_szie, max_time, depth]。batch_size为模型当中batch的大小,应用在文本中时,max_time可以为句子的长度(一般以最长的句子为准,短句需要做padding),depth为输入句子词向量的维度。

代码实践:


import tensorflow as tf
import numpy as np

X = np.random.randn(2, 10, 8)
# The second example is of length 6
X[1, 6:] = 0
X_lengths = [9, 8]

cell = tf.nn.rnn_cell.LSTMCell(num_units=5, state_is_tuple=True)

outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell, cell_bw=cell, dtype=tf.float64, sequence_length=X_lengths, inputs=X
)

output_fw, output_bw = outputs
states_fw, states_bw = states

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    states_shape = tf.shape(states)
    print(states_shape.eval())
    c_f, h_f = states_fw
    o_f = output_fw
    c_b, h_b= states_bw
    o_b = output_bw
    print('c_f\n', sess.run(c_f))
    print('h_f\n', sess.run(h_f))
    print('o_f\n', sess.run(o_f))
    print('c_b\n', sess.run(c_b))
    print('h_b\n', sess.run(h_b))
    print('o_b\n', sess.run(o_b))
    

输出结果:

[2 2 2 5]
c_f
 [[-0.43276965 -0.34707254 -0.09180997  0.26827832  0.27571178]
 [ 0.27575224  0.15156946  0.12256522 -0.1233779  -0.09387333]]
h_f
 [[-0.29557532 -0.19821126 -0.02542468  0.1287899   0.10906331]
 [ 0.13909657  0.07485812  0.0607246  -0.06372124 -0.04719312]]
o_f
 [[[-0.08607912  0.19634355 -0.04141379 -0.09648713 -0.29296226]
  [ 0.0920274   0.12212318  0.06549744 -0.41358432 -0.02210931]
  [ 0.39993605 -0.03604745  0.38421408 -0.17096421  0.07381075]
  [ 0.17104686 -0.08531827  0.04249591  0.05365938  0.1784615 ]
  [ 0.00792906 -0.16713683 -0.02103182  0.07515517  0.06772459]
  [-0.20100924 -0.35576489  0.16194311  0.19446914  0.25483659]
  [-0.18140209 -0.08311345 -0.12816881  0.07098706  0.427926  ]
  [-0.17574083 -0.14505373 -0.23401455  0.15631583  0.39293472]
  [-0.29557532 -0.19821126 -0.02542468  0.1287899   0.10906331]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.09741206 -0.09779295  0.18918836  0.03278753  0.2577792 ]
  [ 0.02267391  0.06850602 -0.0155975  -0.23521581 -0.03577484]
  [ 0.19429619  0.06276382  0.10905737 -0.15550532 -0.01645063]
  [ 0.10287525  0.20157    -0.02434073 -0.11422428  0.00976497]
  [-0.05227936  0.32488201 -0.06576368 -0.11532339 -0.13688021]
  [ 0.22518737  0.10516309  0.12899814 -0.1449693  -0.00556297]
  [ 0.20323779  0.11170567  0.10008328 -0.08086347 -0.03259825]
  [ 0.13909657  0.07485812  0.0607246  -0.06372124 -0.04719312]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]]]
c_b
 [[ 0.09738836  0.19225204 -0.09284249 -0.47382426  0.00350991]
 [ 0.44997505  0.19447785  0.49119047 -0.44252046  0.31626763]]
h_b
 [[ 0.04020806  0.07441591 -0.03619023 -0.0777202   0.00210421]
 [ 0.13888461  0.09810557  0.15060079 -0.28964412  0.16514088]]
o_b
 [[[ 0.04020806  0.07441591 -0.03619023 -0.0777202   0.00210421]
  [ 0.30703784 -0.26512975  0.0314823   0.10928937  0.28692156]
  [ 0.10526645 -0.23850117  0.07682261  0.28263213  0.21087581]
  [-0.434327   -0.22145861 -0.21542902  0.3141704   0.31082225]
  [-0.28739246 -0.20374412 -0.02041121  0.15277031  0.22083064]
  [-0.39453516 -0.17825176  0.0045626   0.16392225  0.35356923]
  [-0.16832858 -0.00360075 -0.18095353  0.04436001  0.35192945]
  [-0.09912457 -0.12665507  0.00639166  0.12355956 -0.0580625 ]
  [-0.11977808 -0.08957523  0.07406649 -0.00428107 -0.11181204]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.13888461  0.09810557  0.15060079 -0.28964412  0.16514088]
  [ 0.07945991  0.16753371 -0.09477983 -0.27083062 -0.16861312]
  [ 0.17703591  0.10670111  0.05483377  0.00054334  0.03132806]
  [ 0.00802494  0.20236404 -0.12328111 -0.07817032 -0.00155747]
  [ 0.07895785  0.13487085  0.03472546 -0.04419926 -0.03887194]
  [ 0.26463148 -0.05714632  0.16954721  0.03967012  0.10644822]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]]]

猜你喜欢

转载自blog.csdn.net/qq_41424519/article/details/82112904