LSTM-tf.nn.static_rnn与tf.nn.dynamic_rnn.用法详解

最近研究LSTM的网络,想将LSTM应用到图像上,查资料发现,用到图像上的LSTM叫ConvLSTM,在这里记录下最核心的两个函数用法:
tf.nn.static_rnn与tf.nn.dynamic_rnn.

这两个函数是tensoflow针对RNN的LSTM提供的两个函数,两个函数的功能上其实差不多,但是tf.nn.dynamic()函数更加灵活.这里我还主要讲解函数用法没和两个函数输入数据形式的区别:

1.tf.nn.dynamic_rnn

def dynamic_rnn(cell, inputs, sequence_length=None, 
        initial_state=None,dtype=None,
        parallel_iterations=None,swap_memory=False,time_major=False, scope=None):
return output,state

parameter:
cell:参数:cell,自己定义的LSTM的细胞单元,如果是convLSTM,自己写也可以,.
下面两个链接提供cell函数(都可以用):https://github.com/TakuyaShinmura/conv_lstm/blob/master/conv_lstm_cell.py
https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py

也可以用tensorflow自带的API:tf.contrib.rnn.ConvLSTMCell()

inputs:一个5维的变量,[batchsize,timestep,image.shape],搭配time_major=False.这里还补充一点,就是叫dynamic的原因,就是输入数据的time_step不一定要相同,如果长短不一,会自动跟短的补0,但是处理时候,不会处理0,在0前面就截止了.这就是dynamic对比static的好处.

time_major: If true, these Tensors must be shaped [max_time, batch_size, depth].
If false, these Tensors must be shaped `[batch_size, max_time, depth]
其实很好理解,如果是true,就是time_step是主导,最前面就是max_time,如果是false,batch_size占主导,batch_size在前面,就是我上面的5维变量输入形式.
其他的参数都可以不用设置,默认就行.

最后所以下函数返回值:这里output是每个cell输出的叠加,比如我输入数据[1,5,100,100,3],是一个长度为5 的视频序列,则返回output为[1,5,100,100,3],5个cell细胞的输出状态,state是一个元组类型的数据,有(c和h两个变量)就是存储LSTM最后一个cell的输出状态,我一般用的是output的最后一个输出..用state输出也行,就是取元组中的h变量.
用下列语句输出:

 outputs = tf.transpose(outputs,[1,0,2,3,4])#这一步必不可少,将max_time提前,后面的output[-1]才是最后一个time的输出,也就是最后一个cell的输出
 last_output=outputs[-1]
然后在处理last_output
'''''''
处理过程
'''''''
return result

2.tf.nn.static_rnn()

def static_rnn(cell, inputs,initial_state=None, dtype=None, 
         sequence_length=None, scope=None)

return outputs,state

参数:cell,自己定义的LSTM的细胞单元,如果是convLSTM,自己写也可以,.
下面两个链接提供cell函数(都可以用):https://github.com/TakuyaShinmura/conv_lstm/blob/master/conv_lstm_cell.py
https://github.com/carlthome/tensorflow-convlstm-cell/blob/master/cell.py

也可以用tensorflow自带的API:tf.contrib.rnn.ConvLSTMCell()

参数:input,和上面的tf.nn.dynamic_rnn有很大不同,这里的input输入是一个List,记住,是list,也就是输入是一个[ ],里面每个List元素都是一组图片,比如[iamges1,images2,images3],images 是有多张图片的一个图片序列.

sequence_length:序列长度,可以不设置,因为input里面可以指定.

最后是返回值:output也是一个list,list的每个元素对应每个image1,image2…的输入,取最后一个list元素,也就是outputs[-1],就是最后一个cell的输出.state和上面一样,记录最后一个cell的状态.

last_output=outputs[-1]
然后在处理last_output
'''''''
处理过程
'''''''
return result

最后这个链接,直接用代码总结,也很好:https://manutdzou.github.io/2017/11/27/tensorflow-lstm.html

最后是一个torch版本的综合应用,弄懂有点困难:https://github.com/viorik/ConvLSTM

好了,相信已经明白怎么使用这个ConLSTM结构了,本人也是初学,有不对的地方多多留言交流.

猜你喜欢

转载自blog.csdn.net/CV_YOU/article/details/81164393