关于RNN相关模型-tensorflow源码理解

本文主要是对tensorflow中lstm模型中的c,h进行解析。rnn_cell_impl.py

1.关于RNN模型

在rnn_cell_impl.py的tensorflow源码中,关于RNN部分实现的类主要是BasicRNNCell,
首先在build函数中,定义了两个变量_kernel和_bias。
这里写图片描述
其中_num_untis表示RNN cell 的untis数目。
所以在call函数中,hidden_state的更新如下所示:
这里写图片描述

从上面中可以看出,RNN首先将input与上一个state连接,然后与在build函数中定义的_kernal变量点乘,最后加上偏置项。

2. 关于LSTM模型

主要看BasicLSTMCell这个类,在build函数中,定义了两个参数_kernel与_bias
这里写图片描述
与RNN不同,参数_kernal与_bias的列都是_num_units的四倍,主要是因为后面要分成四个部分,分别为i,j,f,o。
因此,在call函数中,
这里写图片描述
在call函数中,i,j,f,o可以分别表示为:
这里写图片描述

所以,在上面的图中,最上面的横线表示C,最小面的横线表示h。

3. 关于GRU模型

在GRU模型中的build函数中,可以看到定义了四个参数:
这里写图片描述
因此,在call函数中,

这里写图片描述
从下面的图中可以看出,zt为u,r表示rt,
这里写图片描述

从tensorflow的源码来看,上面的公式中ht的求解有问题,所以参考维基百科,得到下面的公式:
这里写图片描述

发布了98 篇原创文章 · 获赞 337 · 访问量 48万+

猜你喜欢

转载自blog.csdn.net/yiyele/article/details/81987919