tensorflow学习之BasicRNNCell详解

版权声明:微信公众号:数据挖掘与机器学习进阶之路。本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013230189/article/details/82803217

1.循环神经网络

循环神经网络很像前馈神经网络,但是不同的是神经元有连接回指。

如上左图,一个循环神经元可以把自己的输出作为自身的输入,但是这个输入是上一个时间点的输出,如果将上面左图展开就变成右边的图:一个神经元在时间轴上的运行。

图右边的下标代表时间,循环神经元在时间 t 同时接受输入 x(t)和自己在上一时间 t−1的输出结果 y(t−1)

 

2.源码讲解

BasicRNNCell是抽象类RNNCell的一个最简单的实现。

class BasicRNNCell(RNNCell):

   # num_units:输出的神经元数量

   #activation:激活函数

    def __init__(self, num_units, activation=None, reuse=None):

        super(BasicRNNCell, self).__init__(_reuse=reuse)

        self._num_units = num_units

        self._activation = activation or math_ops.tanh

        self._linear = None



   #输出的隐藏状态神经元数量

    @property

    def state_size(self):

        return self._num_units



   #输出的神经元数量

    @property

    def output_size(self):

        return self._num_units

   #接受上一层的输出和隐藏状态神经元作为输入,返回该层的输出和隐藏状态

    def call(self, inputs, state):

        if self._linear is None:

            self._linear = _Linear([inputs, state], self._num_units, True)



        output = self._activation(self._linear([inputs, state]))

        return output, output

 

上面代码主要实现了下图:

公式为:

ht=tanh(Wk[xt,ht−1]+b)

从源代码里可以看到,state_size和output_size都跟num_units都是同一个数字,call函数返回两个一模一样的向量。

3.代码实例

import tensorflow as tf

import numpy as np



batch_size = 10 #批处理大小

input_dim = 100 #输入维度大小,如单词的词向量维度

output_dim = 128 #输出神经元数量



inputs = tf.placeholder(dtype=tf.float32, shape=(batch_size, input_dim))

previous_state = tf.random_normal(shape=(batch_size, output_dim))



cell = tf.contrib.rnn.BasicRNNCell(num_units=output_dim) #一个BasicRNNCell表示一个时间步

output, state = cell(inputs, previous_state) #output:输出神经元数量,state:隐藏神经元数量



X = np.ones(shape=(batch_size, input_dim))

print(output.shape) #(10, 128)

print(state.shape) #(10, 128)

# with tf.Session() as sess:

#     sess.run(tf.global_variables_initializer())

#     o, s = sess.run([output, state], feed_dict={inputs: X})



#     print(X)

#     print(previous_state.eval())

#     print(o)

#     print(s)

 

猜你喜欢

转载自blog.csdn.net/u013230189/article/details/82803217