mnist LSTM 训练、测试,模型保存、加载和识别

原创文章,转载请注明出处:http://blog.csdn.net/wanggao_1990/article/details/77964504

MNIST 字符数据库每个字符(0-9) 对应一张28x28的一通道图片,可以将图片的每一列(行)当作特征,所有行(列)当做一个序列。那么可以通过输入大小为28,时间长度为28的RNN(lstm)对字符建模。对于同一个字符,比如0,其行与行之间的动态变化可以很好地被RNN表示,所有这些连续行的变化表征了某个字符的特定模式。因此可以使用RNN来进行字符识别。

Tensorflow提供了不错的RNN接口,基本思路是
1. 建立RNN网络中的基本单元 cell; tf提供了很多中类型的cell, BasicRNNCell,BasicLSTMCell,LSTMCell 等等
2. 通过调用rnn.static_rnn 函数或者rnn.static_bidirectional_rnn将cell连成RNN 网络。本例子采用的是rnn.static_bidirectional_rnn函数。(版本不同有所区别)

LSTM训练、测试

import os
import numpy as np
'''
A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn

# Import user date convertor
import os
from convert_data import convert_datas

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True)

'''
To classify images using a bidirectional recurrent neural network, we consider
every image row as a sequence of pixels. Because MNIST image shape is 28*28px,
we will then handle 28 sequences of 28 steps for every sample.
'''

# Parameters
learning_rate = 0.001

# 训练迭代次数
training_iters = 100000

# 每次训练的样本大小
batch_size = 128

# 这个是用来显示的。
display_step = 10

# Network Parameters
# n_steps*n_input其实就是那张图 把每一行拆到每个time step上。
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps


# 隐藏层大小
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
# [None, n_steps, n_input]这个None表示这一维(样本数)不确定大小
x = tf.placeholder("float", [None, n_steps, n_input], name="input_x")
y = tf.placeholder("float", [None, n_classes], name="input_y")

# Define weights and biases
weights = tf.Variable(tf.random_normal([2*n_hidden, n_classes]), name="weights")
biases = tf.Variable(tf.random_normal([n_classes]), name="biases")

def BiRNN( x, weights, biases):
    # Prepare data shape to match `bidirectional_rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    # 变成了n_steps*(batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    try:
        outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)
    except Exception: # Old TensorFlow version only returns outputs not states
        outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output

    # return tf.matmul(outputs[-1], weights['out']) + biases['out']
    # return tf.matmul(outputs[-1], weights) + biases

    return tf.add(tf.matmul(outputs[-1], weights), biases)

pred = BiRNN(x, weights, biases)

# Define loss and optimizer
# softmax_cross_entropy_with_logits:Measures the probability error in discrete classification tasks in which the classes are mutually exclusive
# return a 1-D Tensor of length batch_size of the same type as logits with the softmax cross entropy loss.
# reduce_mean就是对所有数值(这里没有指定哪一维)求均值。
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + \
                  ", Training Accuracy= " + "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    # test_len = 128
    # test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    # test_label = mnist.test.labels[:test_len]

    ## Input 为 batch_size*30*17
    ##  实际测试,需要满足 tensorflow的输入placeholder要求
    test_data = mnist.test.images.reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels

    print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

保存训练模型

紧接着上面测试进度输出后,输入以下代码 重复运行即可

    saver = tf.train.Saver()
    model_path = "./model/my_model"
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)

这里只是一种方式,并且保存整个网络结构。 model_path中的model是模型保存的文件夹,my_model是保存模型的前缀,可以理解为模型的名称。

运行完毕后,当前目录会新建名称为“model”的文件夹,且含有四个文件夹:checkpoint、my_model.data-00000-of-00001、my_model.index和my_model.meta。这里的四个文件的有关介绍网上有很多。

注意,这里值是进行了模型的保存,这里保存的目的是为了进行加载并对输入的数据进行测试,并且不需要重建整个网络。因此,还需要对某些计算节点进行保存,在识别阶段利用这些节点计算输出。这里需要增加1个预测节点。在pred = BiRNN(x, weights, biases)后增加:

    tf.add_to_collection('predict', pred)

将pred整个计算和“predict”整个名字绑定在一起,就可以在加载后通过整个名字读取整个运算节点。


加载训练模型 、识别

加载模型很简单,主要代码如下

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./model/my_model.meta')
    new_saver.restore(sess, './model/my_model')

这里需要注意,restore()函数的路径和保存时要一致。

接着,从加载的模型中读取需要的节点。首先是predict节点对应的pred运算,其次这个pred运算需要输入x,也就是训练代码中的占位符“input_x”。继续添加代码如下

    graph = tf.get_default_graph()    
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name("input_x").outputs[0]

最后,就是输入一个图片数据,对其进行识别分类了。

    x = mnist.test.images[0].reshape((1, n_steps, n_input))
    res = sess.run(predict, feed_dict={input_x: x})

这里用的test数据集的第一个图,这里的过程和测试部分类似,只是没有第二个参数label。返回的结果可以通过tf.argmax进行获取类别值。

在利用argmax函数时,需要确认数据的shape,再确定计算的维度。这一部分完整代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True)

n_input = 28
n_steps = 30
n_classes = 2

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./model/my_model.meta')
    new_saver.restore(sess, './model/my_model')

    graph = tf.get_default_graph()
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name("input_x").outputs[0]

    x = mnist.test.images[0].reshape((1, n_steps, n_input))
    y = mnist.test.labels[0].reshape(-1, n_classes)  # 转为one-hot形式

    res = sess.run(predict, feed_dict={input_x: test_data })

    print("Actual class: ", str(sess.run(tf.argmax(y, 1))), \
          ", predict class ",str(sess.run(tf.argmax(res, 1))), \
          ", predict ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(res, 1))))
          )

猜你喜欢

转载自blog.csdn.net/wanggao_1990/article/details/77964504