Defining and Understanding Multi-Layer RNNs

Code:


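# TensorFlow 1.x API (tf.placeholder, tf.nn.rnn_cell); under TensorFlow 2.x these
# are only available via tf.compat.v1 with eager execution disabled
import tensorflow as tf
import numpy as np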

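def get_a_cell():
    # 128 is num_units, i.e. the length of this cell's state vector
    return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# Call get_a_cell() once per layer so each of the 3 layers gets its own cell object (and its own weights)
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

print(cell.state_size)  # one state size per layer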

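# 32 is batch_size, 100 is the length of each input vector
inputs = tf.placeholder(tf.float32, shape=(32, 100))
# zero_state(batch_size, dtype) returns an all-zero initial state: a tuple holding one
# (32, 128) tensor per layer. Only the batch size and dtype are passed in; the
# per-layer state sizes come from the cell itself.
h0 = cell.zero_state(32, tf.float32)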

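# One time step: inputs passes through cell_0 -> cell_1 -> cell_2
output, h1 = cell(inputs, h0)
print(output)  # output of the top layer (cell_2)
print(h1)      # new state: a tuple with one tensor per layer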

Output:

(128, 128, 128)
Tensor("multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32)
(
<tf.Tensor 'multi_rnn_cell/cell_0/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>, 
<tf.Tensor 'multi_rnn_cell/cell_1/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>, 
<tf.Tensor 'multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>
)
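To make the output above concrete, here is a NumPy sketch of what one call to cell(inputs, h0) computes. It assumes, to my understanding of the TF 1.x cells, that each BasicRNNCell computes tanh([x, h] · W + b) and returns that value as both its output and its new state, and that MultiRNNCell feeds each layer's output into the next layer as input. The sizes mirror the example (batch 32, input length 100, 128 units, 3 layers); the random weights and the helper names make_layer, rnn_cell_step and multi_rnn_step are hypothetical and are not TensorFlow's API or initializer.

```python
import numpy as np

# Sizes mirror the TensorFlow example (the constants themselves are hypothetical names)
BATCH_SIZE = 32   # batch_size
INPUT_DIM = 100   # length of each input vector
NUM_UNITS = 128   # num_units of every cell (state vector length)
NUM_LAYERS = 3    # number of stacked cells

rng = np.random.default_rng(0)

def make_layer(input_dim, num_units):
    """Random weights for one tanh RNN layer (illustrative, not TF's initializer)."""
    # A single kernel acting on [x, h] concatenated, plus a bias
    kernel = rng.normal(scale=0.1, size=(input_dim + num_units, num_units))
    bias = np.zeros(num_units)
    return kernel, bias

def rnn_cell_step(params, x, h):
    """h_new = tanh([x, h] @ kernel + bias); output and new state are the same array."""
    kernel, bias = params
    h_new = np.tanh(np.concatenate([x, h], axis=1) @ kernel + bias)
    return h_new, h_new

def multi_rnn_step(layers, x, states):
    """One time step through stacked cells: layer i's output is layer i+1's input."""
    new_states = []
    inp = x
    for params, h in zip(layers, states):
        inp, h_new = rnn_cell_step(params, inp, h)
        new_states.append(h_new)
    # The stack's output is the top layer's output; its state is a tuple of all layers
    return inp, tuple(new_states)

# Layer 0 sees the 100-dim input; layers 1 and 2 see the 128-dim output of the layer below
layers = [make_layer(INPUT_DIM if i == 0 else NUM_UNITS, NUM_UNITS) for i in range(NUM_LAYERS)]
state_size = tuple(NUM_UNITS for _ in range(NUM_LAYERS))
h0 = tuple(np.zeros((BATCH_SIZE, size)) for size in state_size)  # analogous to zero_state(32, ...)

x = rng.normal(size=(BATCH_SIZE, INPUT_DIM))
output, h1 = multi_rnn_step(layers, x, h0)

print(state_size)                      # (128, 128, 128)
print(output.shape)                    # (32, 128)
print([h.shape for h in h1])           # [(32, 128), (32, 128), (32, 128)]
print(np.array_equal(output, h1[-1]))  # True
```

This matches the printed tensors: output and the last entry of h1 share the name multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0 because, for a basic RNN cell, the top layer's output is its new state. h1 additionally carries the states of cell_0 and cell_1, which serve as the initial state for the next time step.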

Reposted from blog.csdn.net/Strive_For_Future/article/details/82025938