代码:
import tensorflow as tf
import numpy as np
def get_a_cell():
### 128 是 状态矢量的长度
return tf.nn.rnn_cell.BasicRNNCell(num_units=128)
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
print(cell.state_size)
## 32 是 batch_size ,100 是 inputs 矢量的长度
inputs = tf.placeholder(np.float32,shape=(32,100))
h0 = cell.zero_state(32,np.float32) ## 通过zero_state得到一个全0的初始状态(只需给出状态的矢量长度即可,因为状态肯定是矢量)
output,h1 = cell(inputs,h0)
print(output)
print(h1)
输出:
(128, 128, 128)
Tensor("multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32)
(
<tf.Tensor 'multi_rnn_cell/cell_0/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>,
<tf.Tensor 'multi_rnn_cell/cell_1/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>,
<tf.Tensor 'multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0' shape=(32, 128) dtype=float32>
)