RNN神经网络、LSTM神经网络、LSTM的变体:GRU神经网络、Tensorflow搭建第一个RNN——LSTM神经网络(分类)

版权声明:站在巨人的肩膀上学习。 https://blog.csdn.net/zgcr654321/article/details/84112554

RNN(Recurrent Neural Network)循环神经网络:

还记得一般的神经网络的结构吗?一般是下面这样的:

对于一般的神经网络,我们总是做这样的前提假设:元素之间是相互独立的,输入与输出也是独立的。 

但是,有些时候我们的输入数据是有顺序的,即前面的输入和后面的输入是有关系的。比如说一段话或者一段视频(可能是不定长度的)。这个时候我们就要使用RNN(Recurrent Neural Network)循环神经网络了。

我们先来看一个最简单的RNN网络,它由输入层、一个隐藏层和一个输出层组成:

这个网络只要去掉w部分就是一个普通的全连接神经网络。

其中:

x是一个输入向量;s是一个隐藏层的输出向量;U是输入层到隐藏层的权重矩阵,O是输出层的输出向量,V是隐藏层到输出层的权重矩阵。

循环神经网络的隐藏层的值s不仅仅取决于当前这次的输入x,还取决于上一次隐藏层的值s。权重矩阵W就是隐藏层上一次的值作为这一次的输入的权重。

RNN循环神经网络的一般性结构:

其中Xt表示t时刻的输入,Ot表示t时刻的输出,St表示t时刻的记忆,f1和f2代表不同层上的激活函数。

虽然说RNN处理的是不定长输入数据,但是某个时刻的输入还是定长的。RNN和CNN有着同样的共享权值的属性,所以不同时刻的U,V,W都是相同的,所有整个网络的学习目标就是优化这些参数以及偏置。

以时间为轴展开的RNN网络结构如下:

注意:

在实际运行中,Xt-1、Xt和Xt+1用的都是同一批神经元,只是它们输入的数据的输入时刻不一样(你可以把它们理解为数据是一批一批输入的,然后每一批都有一个输出,不过一般我们最后只用到最后一个时刻的输出,即final_state中的c输出)。

双向RNN循环神经网络:

上面的RNN网络是单向的,即总是会记住前面时刻的一部分信息,但没有记录后面时刻的信息。双向RNN网络就是为了同时记录前面时刻和后面时刻的信息而发明的。其结构如下:

双向RNN网络可以同时正向和反向读取序列,上面和下面的绿色方框都代表隐藏层,这时上方的隐藏层记录后面时刻的信息,而下方的隐藏层记录前面时刻的信息,在输出时,将一上一下两个隐藏层的信息合在一起再处理后输出。

举例:

以上图y1为例。

抽象到一般情况下的计算公式为:

注意:

这里计算的两个W权重矩阵是不同的,因为正向和反向是两批不同的神经元。

上面的单向RNN和双向RNN中的隐藏层都可以由单个扩展成多个,计算方式大同小异,不再赘述。

LSTM(Long Short-Term Memory)长短期记忆网络:

LSTM,全称为长短期记忆网络(Long Short Term Memory networks),是一种特殊的RNN,能够学习到长期依赖关系。

LSTM区别于普通RNN的地方在于它在算法中加入了一个判断信息有用与否的“处理器”,这个处理器作用的结构被称为cell,即门控单元。一个cell当中被放置了三扇门,分别叫做输入门、遗忘门和输出门。

一个信息进入LSTM的网络当中,可以根据规则来判断是否有用。只有符合算法认证的信息才会留下,不符的信息则通过遗忘门被遗忘。

LSTM网络的单元结构:

单元状态(cell state):

即图中LSTM单元上方从左贯穿到右的水平线,它是LSTM的关键,像传送带一样,将信息从上一个单元传递到下一个单元,和其他部分只有很少的线性的相互作用。 

LSTM通过“门”(gate)来控制丢弃或者增加信息,从而实现遗忘或记忆的功能。“门”是一种使信息选择性通过的结构,由一个sigmoid函数和一个点乘操作组成。sigmoid函数的输出值在[0,1]区间,0代表完全丢弃,1代表完全通过。

一个LSTM单元有三个这样的门,分别是遗忘门(forget gate)、输入门(input gate)、输出门(output gate)。

遗忘门(forget gate):

遗忘门是以上一单元的输出ht−1和本单元的输入xt为输入的sigmoid函数,为Ct−1中的每一项产生一个在[0,1]内的值,来控制上一单元状态被遗忘的程度。1表示“完全保留”,0 表示“完全舍弃”。

公式如下:

输入门(input gate):

输入门和一个tanh函数配合控制有哪些新信息被加入。tanh函数产生一个新的候选向量,输入门为中的每一项产生一个在[0,1]内的值,控制新信息被加入的多少。

公式如下:

现在我们已经有了遗忘门的输出ft,用来控制上一单元被遗忘的程度,也有了输入门的输出it,用来控制新信息被加入的多少,我们就可以更新本记忆单元的单元状态了。

更新本单元的记忆状态:

旧状态与ft相乘,丢弃掉我们确定需要丢弃的信息。接着加上it* 这就是新的记忆状态。

输出门(output gate):

输出门用来控制当前的单元状态有多少被过滤掉。先将单元状态激活,输出门为其中每一项产生一个在[0,1]内的值,控制单元状态被过滤的程度。 接着,我们把更新后的记忆状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。

公式如下:

回到最初的单元状态:

输入的Ct-1先经过遗忘门,根据遗忘门的值来决定保留多少百分比的Ct-1的数据值;然后数据经过输入门,根据输入门的运算结果来加入一部分本单元更新的数据值,这两步合并起来即为:

然后Ct数据经过输出门,根据输出门Ot的值来处理Ct的值。同时也记录ht。

LSTM的变体:GRU(Gated Recurrent Unit)神经网络

回顾一下LSTM的模型,LSTM实现了三个门计算,即遗忘门、输入门和输出门。

GRU模型只有两个门,分别为更新门和重置门,即将忘记门和输入门合成了一个单一的更新门。更新门和重置门即下图中的zt和rt。

更新门用于控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多;

重置门用于控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。

计算公式如下:

Tensorflow搭建第一个RNN——LSTM神经网络(分类):

我们使用Tensorflow来搭建一个RNN——LSTM分类神经网络。使用的数据集仍然是MNIST手写数字数据集。我们让神经网络从每张图片的第一行像素读到最后一行,然后进行分类判断。

代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

tf.set_random_seed(1)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 超参数
lr = 0.0001
training_iteration = 2000
train_batch_size = 100
test_batch_size = 100
# MNIST的图片是28X28像素,我们把每一行的像素点看成一个时间间隔内输入的数据(28个像素点),一张图片是28个时间间隔
# 每个时间间隔内输入一行的28个列上的元素
inputs_cols = 28
# 有多少行就是有多少个时间间隔,即步数
steps_rows = 28
# 隐藏层神经元数量
n_hidden_units = 128
# 10种分类
mnist_classes = 10

# X和Y占位符
X = tf.placeholder(tf.float32, [None, steps_rows, inputs_cols])
Y = tf.placeholder(tf.float32, [None, mnist_classes])


def weight_variable(shape):
	inital = tf.truncated_normal(shape, stddev=0.1)
	return tf.Variable(inital)


def bias_variable(shape):
	initial = tf.constant(0.1, shape=shape)
	return tf.Variable(initial)


def RNN_LSTM(x, stps, one_stp_inputs, hid_units, classes, bt_size):
	# x reshape成(128batches*28steps,28inputs)
	x = tf.reshape(x, [-1, one_stp_inputs])

	w_in = weight_variable([one_stp_inputs, hid_units])
	w_out = weight_variable([hid_units, classes])

	b_in = bias_variable([hid_units, ])
	b_out = bias_variable([classes, ])

	x_in = tf.matmul(x, w_in) + b_in
	# X_in reshape成(128batches,28steps,128hidden)
	x_in = tf.reshape(x_in, [-1, stps, n_hidden_units])

	# tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True):
	# n_hidden表示LSTM cell层中神经元的个数,forget_bias就是LSTM们的忘记系数,如果等于1,就是不会忘记任何信息。如果等于0,就都忘记。
	# state_is_tuple默认就是True,官方建议用True,就是表示返回的状态用一个元祖表示。
	lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
	# 状态初始化函数zero_state(batch_size,dtype):batch_size就是输入样本批次的数目,dtype就是数据类型。
	# state初始化全零
	init_state = lstm_cell.zero_state(bt_size, dtype=tf.float32)
	# time_major如果是True,就表示RNN的steps用第一个维度表示,如果是False,那么输入的第二个维度就是steps。
	# 如果是True,output的维度是[steps, batch_size, depth],反之就是[batch_size, max_time, depth]。就是和输入维度一样。
    # 为True时,我们是以LSTM为tf.nn.dynamic_rnn的输入cell类型,此时state形状为[2,batch_size, cell.output_size];
	# 为False时,以GRU为tf.nn.dynamic_rnn的输入cell类型,此时state形状为[batch_size, cell.output_size]
    # 如果cell选择了LSTM,那final_state是个tuple,分别代表Ct和ht,其中ht与outputs中的对应的最后一个时刻的输出相等;
	# 假设final_state形状为[2,batch_size, cell.output_size],outputs形状为[batch_size, max_time, cell.output_size]
	# 那么final_state[1,batch_size,:] == outputs[batch_size, -1,:]
	# 如果cell是GRU,final_state其实就是ht,final_state==outputs[-1]
	# 其原因是因为LSTM和GRU的结构本身不同,GRU将遗忘门和输入门合并成了更新门
	# 当cell为GRU时,state就只有一个了,原因是GRU将Ct和ht进行了简化,将其合并成了ht
	# final_state就是整个LSTM输出的最终的状态,包含c和h。c和h的维度都是[batch_size, n_hidden]
	outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, x_in, initial_state=init_state, time_major=False)
	# output里面,包含了所有时刻的输出h,state里面,包含了最后一个时刻的输出c和所有时刻的输出h;
	# 如果只需要最后一个时刻的状态输出,直接使用state里面的h输出就可以了。即final_state[1]
	# c和h的维度都是[batch_size, n_hidden]
	results = tf.matmul(final_state[1], w_out) + b_out
	return results


y_pred = RNN_LSTM(X, steps_rows, inputs_cols, n_hidden_units, mnist_classes, train_batch_size)
# 输出数据仍然用softmax函数处理
# cost函数仍然使用交叉熵函数
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=y_pred))
# 优化器使用Adam优化算法
optimizer = tf.train.AdamOptimizer(lr).minimize(cost)
# 得到的输出数据是标签矩阵,shape=(batchsize,10),因此我们取第一个维度上的最大值下标,然后看预测和真实的标签是否相等,返回布尔值
correct_pred = tf.equal(tf.argmax(y_pred, 1), tf.argmax(Y, 1))
# 如果是布尔值,可将其转为0和1的float值序列
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	for i in range(training_iteration):
		train_batch_xs, train_batch_ys = mnist.train.next_batch(train_batch_size)
		train_batch_xs = train_batch_xs.reshape([train_batch_size, steps_rows, inputs_cols])
		_, loss = sess.run([optimizer, cost], feed_dict={X: train_batch_xs, Y: train_batch_ys})
		if i % 20 == 0:
			test_batch_xs, test_batch_ys = mnist.train.next_batch(test_batch_size)
			test_batch_xs = test_batch_xs.reshape([test_batch_size, steps_rows, inputs_cols])
			acc = sess.run(accuracy, feed_dict={X: test_batch_xs, Y: test_batch_ys})
			print("iteration:{} loss:{} acc:{}".format(i, loss, acc))

运行结果如下:

iteration:0 loss:2.3076298236846924 acc:0.11999999731779099
iteration:20 loss:2.274477958679199 acc:0.23000000417232513
iteration:40 loss:2.2652275562286377 acc:0.3100000023841858
iteration:60 loss:2.1544978618621826 acc:0.3499999940395355
iteration:80 loss:2.0182301998138428 acc:0.46000000834465027
iteration:100 loss:1.851188063621521 acc:0.41999998688697815
iteration:120 loss:1.7927887439727783 acc:0.3700000047683716
iteration:140 loss:1.706325650215149 acc:0.44999998807907104
iteration:160 loss:1.526696801185608 acc:0.5799999833106995
iteration:180 loss:1.505893588066101 acc:0.6000000238418579
iteration:200 loss:1.373097538948059 acc:0.6399999856948853
iteration:220 loss:1.3457021713256836 acc:0.6200000047683716
iteration:240 loss:1.1731350421905518 acc:0.699999988079071
iteration:260 loss:0.945415735244751 acc:0.6499999761581421
iteration:280 loss:1.0089123249053955 acc:0.6800000071525574
iteration:300 loss:0.8178996443748474 acc:0.699999988079071
iteration:320 loss:0.9596269130706787 acc:0.7699999809265137
iteration:340 loss:0.8439638018608093 acc:0.75
iteration:360 loss:0.7150177955627441 acc:0.7400000095367432
iteration:380 loss:0.6129633188247681 acc:0.8700000047683716
iteration:400 loss:0.6865827441215515 acc:0.7599999904632568
iteration:420 loss:0.7067865133285522 acc:0.8399999737739563
iteration:440 loss:0.581174910068512 acc:0.8100000023841858
iteration:460 loss:0.638022243976593 acc:0.7799999713897705
iteration:480 loss:0.5141828656196594 acc:0.8799999952316284
iteration:500 loss:0.6109911799430847 acc:0.8700000047683716
iteration:520 loss:0.5281422138214111 acc:0.8299999833106995
iteration:540 loss:0.36358845233917236 acc:0.8899999856948853
iteration:560 loss:0.32377177476882935 acc:0.8600000143051147
iteration:580 loss:0.5583460927009583 acc:0.8999999761581421
iteration:600 loss:0.3640763461589813 acc:0.9100000262260437
iteration:620 loss:0.3980554938316345 acc:0.8700000047683716
iteration:640 loss:0.4693172574043274 acc:0.8899999856948853
iteration:660 loss:0.4157080054283142 acc:0.8799999952316284
iteration:680 loss:0.34606510400772095 acc:0.9200000166893005
iteration:700 loss:0.37909436225891113 acc:0.8899999856948853
iteration:720 loss:0.2907980680465698 acc:0.9300000071525574
iteration:740 loss:0.3360324501991272 acc:0.9300000071525574
iteration:760 loss:0.31698155403137207 acc:0.8799999952316284
iteration:780 loss:0.4803481996059418 acc:0.9200000166893005
iteration:800 loss:0.4487611651420593 acc:0.8299999833106995
iteration:820 loss:0.23545436561107635 acc:0.8799999952316284
iteration:840 loss:0.408072829246521 acc:0.8899999856948853
iteration:860 loss:0.3766385614871979 acc:0.8899999856948853
iteration:880 loss:0.23924116790294647 acc:0.8799999952316284
iteration:900 loss:0.3980361521244049 acc:0.8899999856948853
iteration:920 loss:0.30980074405670166 acc:0.9200000166893005
iteration:940 loss:0.3675287365913391 acc:0.9200000166893005
iteration:960 loss:0.33362916111946106 acc:0.9200000166893005
iteration:980 loss:0.1964966058731079 acc:0.949999988079071
iteration:1000 loss:0.2528787851333618 acc:0.949999988079071
iteration:1020 loss:0.28432559967041016 acc:0.9399999976158142
iteration:1040 loss:0.18621350824832916 acc:0.8899999856948853
iteration:1060 loss:0.28590288758277893 acc:0.8999999761581421
iteration:1080 loss:0.2063649296760559 acc:0.949999988079071
iteration:1100 loss:0.3142355680465698 acc:0.9599999785423279
iteration:1120 loss:0.23007726669311523 acc:0.9100000262260437
iteration:1140 loss:0.1843377947807312 acc:0.949999988079071
iteration:1160 loss:0.17146600782871246 acc:0.9200000166893005
iteration:1180 loss:0.2453075349330902 acc:0.9399999976158142
iteration:1200 loss:0.19952549040317535 acc:0.9300000071525574
iteration:1220 loss:0.21484214067459106 acc:0.9100000262260437
iteration:1240 loss:0.2458951622247696 acc:0.9599999785423279
iteration:1260 loss:0.24367006123065948 acc:0.9300000071525574
iteration:1280 loss:0.27452895045280457 acc:0.9100000262260437
iteration:1300 loss:0.32453295588493347 acc:0.9599999785423279
iteration:1320 loss:0.19615527987480164 acc:0.949999988079071
iteration:1340 loss:0.16604425013065338 acc:0.949999988079071
iteration:1360 loss:0.28147247433662415 acc:0.9399999976158142
iteration:1380 loss:0.28666773438453674 acc:0.9599999785423279
iteration:1400 loss:0.30330950021743774 acc:0.9399999976158142
iteration:1420 loss:0.14046263694763184 acc:0.949999988079071
iteration:1440 loss:0.26582375168800354 acc:0.949999988079071
iteration:1460 loss:0.15203167498111725 acc:0.9399999976158142
iteration:1480 loss:0.199817955493927 acc:0.949999988079071
iteration:1500 loss:0.23183047771453857 acc:0.9599999785423279
iteration:1520 loss:0.24718023836612701 acc:0.9300000071525574
iteration:1540 loss:0.21721810102462769 acc:0.9399999976158142
iteration:1560 loss:0.18312248587608337 acc:0.9300000071525574
iteration:1580 loss:0.09985671937465668 acc:0.9399999976158142
iteration:1600 loss:0.31880244612693787 acc:0.9399999976158142
iteration:1620 loss:0.11516935378313065 acc:0.9300000071525574
iteration:1640 loss:0.12962393462657928 acc:0.9700000286102295
iteration:1660 loss:0.14879734814167023 acc:0.9700000286102295
iteration:1680 loss:0.1875951886177063 acc:0.9100000262260437
iteration:1700 loss:0.1656365990638733 acc:0.9700000286102295
iteration:1720 loss:0.12901760637760162 acc:0.949999988079071
iteration:1740 loss:0.19696255028247833 acc:0.9599999785423279
iteration:1760 loss:0.3279583752155304 acc:0.9399999976158142
iteration:1780 loss:0.22977523505687714 acc:0.9399999976158142
iteration:1800 loss:0.2246321737766266 acc:0.9700000286102295
iteration:1820 loss:0.15428473055362701 acc:0.9399999976158142
iteration:1840 loss:0.09635346382856369 acc:0.9300000071525574
iteration:1860 loss:0.1093805581331253 acc:0.9800000190734863
iteration:1880 loss:0.1567453294992447 acc:0.9100000262260437
iteration:1900 loss:0.13261663913726807 acc:0.9399999976158142
iteration:1920 loss:0.1055414006114006 acc:0.9399999976158142
iteration:1940 loss:0.12914611399173737 acc:0.9700000286102295
iteration:1960 loss:0.159665048122406 acc:0.9100000262260437
iteration:1980 loss:0.31085625290870667 acc:0.9300000071525574

Process finished with exit code 0

猜你喜欢

转载自blog.csdn.net/zgcr654321/article/details/84112554