用tensorflow构建动态RNN

直接看代码

def create_cell():
    cell = rnn.LSTMCell(num_units)
    return rnn.DropoutWrapper(cell, input_keep_prob=0.5)

rnn_cell = rnn.MultiRNNCell([create_cell() for _ in range(2)])
output, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)

相关API:

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

参数

cell:一种rnn 的cell,本实例中传入了一个多层的rnncell,每层cell的基本单元是LSTMCell,并且使用了dropout

inputs:输入数据

如果 time_major == False (default)
input的形状必须为 [batch_size, max_time, embed_size]

如果 time_major == True
input输入的形状必须为 [max_time, batch_size, embed_size]

其中batch_size是批大小,max_time是每个序列的大小,而embed_size是序列里面每个分量的大小


返回的是一个元组 (outputs, state)

outputs:RNN的最后一层的输出,是一个tensor
如果为time_major== False,则shape [batch_size,max_time,cell.output_size]。如果为time_major== True,则shape: [max_time,batch_size,cell.output_size]。cell.output_size就是num_units

state: RNN最后时间步的state,如果cell.state_size是一个整数(一般是单层的RNNCell),则state的shape:[batch_size,cell.state_size]。如果它是一个元组(一般这里是 多层的RNNCell),那么它将是一个具有相应形状的元组。注意:如果若RNNCell是 LSTMCells,则state将为每层cell的LSTMStateTuple的元组Tuple(LSTMStateTuple,LSTMStateTuple,LSTMStateTuple)
 

发布了39 篇原创文章 · 获赞 8 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/jancywen/article/details/88880153
今日推荐