【TensorFlow】Digit Recognition: RNN Algorithm


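  In the earlier article on digit recognition with softmax regression, classifying the images with softmax regression reached an accuracy of about 92%. How can we push the accuracy higher? This article performs digit recognition with an RNN, specifically a two-layer LSTM network.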
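An RNN consumes a sequence, so the key step is to view each image as one. The short NumPy sketch below is purely illustrative (NumPy is not part of the article's model, and the pixel values are made up): it assumes each MNIST image arrives as a flattened, row-major vector of 784 pixels, and shows that reshaping it to (batch, 28, 28) makes time step t equal to row t of the image, i.e. 28 consecutive pixels.

```python
import numpy as np

# A made-up "flattened" 28x28 image whose pixel values are 0..783 in row-major
# order, mimicking the 784-dimensional vectors returned by the MNIST loader.
flat = np.arange(784, dtype=np.float32)

# Same layout as tf.reshape(_X, [-1, 28, 28]) for a batch of one image:
# (batch, time_steps, input_size)
seq = flat.reshape(-1, 28, 28)

# Time step t is image row t: 28 consecutive pixels.
first_step = seq[0, 0]   # pixels 0..27, the top row
last_step = seq[0, -1]   # pixels 756..783, the bottom row
print(seq.shape, first_step[:3], last_step[:3])
```

Because the flattened vector is row-major, each time step is a row, not a column; reading columns instead would require transposing the last two axes first.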

# TensorFlow 1.x code: tf.contrib and tensorflow.examples.tutorials were removed in TensorFlow 2.x
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST data from data/mnist, with one-hot encoded labels
mnist = input_data.read_data_sets('data/mnist', one_hot=True)

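# Each image is 28x28 pixels. Each row is treated as one time step,
# so time_step_size = 28 (rows) and input_size = 28 (pixels per row).
time_step_size = 28
input_size = 28
hidden_size = 256  # number of units in each LSTM layer
layer_size = 2     # number of stacked LSTM layers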

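# Flattened 784-pixel images and their one-hot labels
_X = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Reshape to [batch, time_steps, input_size], the batch-major layout dynamic_rnn expects
X = tf.reshape(_X, [-1, time_step_size, input_size])
# Dropout keep probability: 0.5 during training, 1.0 for evaluation
keep_prob = tf.placeholder(tf.float32)
# Batch size, used to build the all-zero initial LSTM state
batch_size = tf.placeholder(tf.int32, [])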

# Define one LSTM layer, with dropout applied to its outputs
def lstm_cell():
    cell = rnn.BasicLSTMCell(num_units=hidden_size, reuse=tf.get_variable_scope().reuse)
    return rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)


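# Stack layer_size LSTM layers; each call to lstm_cell() creates a separate layer
mlstm_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(layer_size)], state_is_tuple=True)

# All-zero initial state for every layer
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

# Unroll over the 28 time steps; outputs has shape [batch, 28, hidden_size]
outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)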

# Use the output of the last time step as the input to the softmax layer
h_state = outputs[:, -1, :]

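# Output layer: logits = h_state * W + b, followed by softmax
W = tf.Variable(tf.truncated_normal([hidden_size, 10], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.matmul(h_state, W) + b
prediction = tf.nn.softmax(logits)

# Loss: cross-entropy summed over the batch, computed from the logits so that a
# softmax probability underflowing to 0 cannot produce log(0) = -inf
cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))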

# Optimize the model with Adam (learning rate 1e-3)
train_op = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)

# Boolean vector: whether each predicted class matches the label
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))

# Accuracy: fraction of correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Train for 2001 steps on mini-batches of 100 images, with dropout (keep_prob=0.5).
# Every 200 steps, report accuracy on the current batch with dropout disabled.
for i in range(2001):
    xs, ys = mnist.train.next_batch(100)
    if i % 200 == 0:
        train_accuracy = sess.run(accuracy, feed_dict={_X: xs, y: ys, keep_prob: 1.0, batch_size: 100})
        print("Iter%d, step %d, training accuracy %g" % (mnist.train.epochs_completed, i, train_accuracy))
    sess.run(train_op, feed_dict={_X: xs, y: ys, keep_prob: 0.5, batch_size: 100})

# Evaluate on the full test set; batch_size must equal the number of test images
print("test accuracy %g" % sess.run(accuracy, feed_dict={_X: mnist.test.images, y: mnist.test.labels,
                                                        keep_prob: 1.0, batch_size: mnist.test.labels.shape[0]}))
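To make the unrolled network less of a black box, here is a NumPy sketch of the forward pass for a single LSTM layer. It is a simplification under stated assumptions, not the TensorFlow implementation: it follows the gate layout commonly described for BasicLSTMCell (input and previous h concatenated, then split into gates i, j, f, o, with forget_bias = 1.0 added to the forget gate), uses one layer instead of two, omits dropout, and uses a hypothetical hidden size of 16 with random weights and random images.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, W, b, forget_bias=1.0):
    """One LSTM step in the style of BasicLSTMCell (assumed gate order i, j, f, o)."""
    z = np.concatenate([x, h], axis=1) @ W + b       # (batch, 4 * hidden)
    i, j, f, o = np.split(z, 4, axis=1)               # input, candidate, forget, output
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

def rnn_logits(images, W_lstm, b_lstm, W_out, b_out, hidden):
    """Run one LSTM layer over the 28 image rows and classify from the last output."""
    seq = images.reshape(-1, 28, 28)                  # (batch, time, input)
    batch = seq.shape[0]
    h = np.zeros((batch, hidden))                     # all-zero initial state
    c = np.zeros((batch, hidden))
    for t in range(seq.shape[1]):                     # one row per time step
        h, c = lstm_step(seq[:, t, :], h, c, W_lstm, b_lstm)
    return h @ W_out + b_out                          # analogous to outputs[:, -1, :] @ W + b

rng = np.random.default_rng(0)
hidden = 16                                           # hypothetical small size
images = rng.random((5, 784))                         # 5 random "images"
W_lstm = rng.normal(0, 0.1, (28 + hidden, 4 * hidden))
b_lstm = np.zeros(4 * hidden)
W_out = rng.normal(0, 0.1, (hidden, 10))
b_out = np.full(10, 0.1)
logits = rnn_logits(images, W_lstm, b_lstm, W_out, b_out, hidden)
print(logits.shape)
```

Because the state starts at zero and is updated row by row, the final h summarizes the whole image; this is why only the last time step's output is fed into the softmax layer.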

  The test accuracy is 98.31%, compared with about 92% for softmax regression.
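The accuracy figure is the fraction of test images whose argmax prediction equals the argmax of the one-hot label. The NumPy sketch below is illustrative only (the two-class toy data is made up for the example): it reproduces that metric and also shows why the summed cross-entropy written as -Σ y·log(softmax(z)) is fragile. If a softmax output underflows to exactly 0, log returns -inf and the loss becomes inf. Computing log-softmax with the log-sum-exp trick keeps it finite, which is the motivation for TensorFlow's fused ops such as tf.nn.softmax_cross_entropy_with_logits taking logits rather than probabilities.

```python
import numpy as np

def log_softmax(logits):
    # Log-sum-exp trick: subtract the row max before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def naive_cross_entropy(y, logits):
    # -sum(y * log(softmax(logits))): breaks when a probability underflows to 0.
    p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore"):
        return -np.sum(y * np.log(p))

def stable_cross_entropy(y, logits):
    # Same quantity, computed from log-softmax so it stays finite.
    return -np.sum(y * log_softmax(logits))

def accuracy(y, logits):
    # Fraction of rows where the predicted class matches the one-hot label.
    return np.mean(np.argmax(logits, axis=1) == np.argmax(y, axis=1))

# Made-up toy batch: 2 samples, 2 classes.
y = np.array([[0.0, 1.0], [1.0, 0.0]])
logits = np.array([[0.0, -800.0], [3.0, 1.0]])   # exp(-800) underflows to 0

print(naive_cross_entropy(y, logits), stable_cross_entropy(y, logits), accuracy(y, logits))
```

The first sample is confidently wrong, so the naive loss overflows to inf while the stable version reports a large but finite value; the second sample is classified correctly, giving an accuracy of 0.5 on this toy batch.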


Reposted from blog.csdn.net/lionel_fengj/article/details/80487991