(Series complete) Deep learning from scratch: running the MNIST dataset with the TensorFlow framework, day 3: testing the model

1. Introduction

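Today is day three of trying to run the MNIST handwritten-digit dataset with the TensorFlow framework; the focus is on testing the network. This blog mainly records a learning path and collects the learning materials used along the way.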

Note: the code is written for Python 2.7 and uses the TensorFlow 1.x API (tf.placeholder, tf.Session).

Day 1 (building the LeNet network): https://blog.csdn.net/qq_36627158/article/details/108245969

Day 2 (training the network): https://blog.csdn.net/qq_36627158/article/details/108315239

Day 3 (testing the network): https://blog.csdn.net/qq_36627158/article/details/108321673

Day 4 (testing a single sample): https://blog.csdn.net/qq_36627158/article/details/108397018

2. Code (mnist_test.py)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet



def test_model(test_dataset):
    num_of_test_data = test_dataset.images.shape[0]

    images_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[num_of_test_data, 28, 28, 1]
    )
    labels_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[num_of_test_data, 10]
    )

    # Each image is stored as a flat 784-vector; reshape to [N, 28, 28, 1] for LeNet.
    test_images = test_dataset.images.reshape(num_of_test_data, 28, 28, 1)
    test_labels = test_dataset.labels

    # Forward pass through the LeNet model defined on day 1.
    label_predict = mnist_lenet.build_model_and_forward(images_holder)

    # Compare the true class index with the predicted class index per image.
    correct_prediction = tf.equal(
        tf.argmax(labels_holder, 1),
        tf.argmax(label_predict, 1)
    )
    # Cast booleans to floats and average to get the fraction classified correctly.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()

    with tf.Session() as sess:
        test_feed = {
            images_holder: test_images,
            labels_holder: test_labels
        }

        ckpt = tf.train.get_checkpoint_state("./models")

        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            accuracy_score = sess.run(accuracy, feed_dict=test_feed)
            print("After training the model, the test accuracy = %g %%"
                  % (accuracy_score * 100))
        else:
            print("No checkpoint file found")
            return


if __name__ == '__main__':
    mnist_data = input_data.read_data_sets("MNIST_data/", one_hot=True)
    test_model(mnist_data.test)

3. Code Details

(1) tf.equal()

https://cloud.tencent.com/developer/article/1406384
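
As a rough illustration of what the accuracy computation in mnist_test.py does, the NumPy sketch below mirrors the tf.argmax / tf.equal / tf.reduce_mean chain on a tiny hand-made batch. The arrays `labels` and `logits` hold made-up values (4 samples, 3 classes instead of 10); they are assumptions for illustration, not MNIST data or real network output.

```python
import numpy as np

# Hypothetical one-hot labels for 4 samples and 3 classes (made-up values).
labels = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [0, 1, 0]], dtype=np.float32)

# Hypothetical network outputs for the same 4 samples (made-up values).
logits = np.array([[2.0, 0.1, 0.3],   # largest score at class 0 (correct)
                   [0.2, 1.5, 0.1],   # largest score at class 1 (correct)
                   [0.9, 0.3, 0.4],   # largest score at class 0 (true class is 2)
                   [0.1, 0.8, 0.2]],  # largest score at class 1 (correct)
                  dtype=np.float32)

# argmax along axis 1 turns each row into a single class index.
true_classes = np.argmax(labels, axis=1)
predicted_classes = np.argmax(logits, axis=1)

# Element-wise comparison yields one boolean per sample, as tf.equal does.
correct_prediction = np.equal(true_classes, predicted_classes)

# Casting the booleans to floats and averaging gives the fraction correct.
accuracy = float(np.mean(correct_prediction.astype(np.float32)))
print("accuracy = %g %%" % (accuracy * 100))
```

Three of the four made-up samples match, so the averaged value is 0.75. In the TensorFlow version the same three steps are graph ops, and the number only appears once sess.run(accuracy, ...) is called.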

(2) tf.train.get_checkpoint_state()
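
tf.train.get_checkpoint_state() looks for a small text file named `checkpoint` in the given directory and returns an object whose model_checkpoint_path field names the most recent checkpoint, or None when that file is missing. This is why mnist_test.py guards with `if ckpt and ckpt.model_checkpoint_path`. The sketch below is not TensorFlow's implementation: it writes a sample `checkpoint` file into a temporary directory and reads the field back with a hypothetical helper, `read_latest_checkpoint`, purely to show what the file looks like. The checkpoint name `model.ckpt-1000` is a made-up example.

```python
import os
import re
import tempfile


def read_latest_checkpoint(checkpoint_dir):
    """Hypothetical helper: return the model_checkpoint_path recorded in the
    'checkpoint' state file, or None if the file or the field is missing."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # The anchor ^ keeps this from matching all_model_checkpoint_paths.
            match = re.match(r'^model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                return os.path.join(checkpoint_dir, match.group(1))
    return None


# Build a throwaway directory containing a sample state file.
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "checkpoint"), "w") as f:
    f.write('model_checkpoint_path: "model.ckpt-1000"\n')
    f.write('all_model_checkpoint_paths: "model.ckpt-1000"\n')

found = read_latest_checkpoint(demo_dir)
missing = read_latest_checkpoint(os.path.join(demo_dir, "no_such_dir"))
print(found)    # path ending in model.ckpt-1000
print(missing)  # None, because there is no state file to read
```

In mnist_test.py the path found this way is passed to saver.restore(sess, ...), which loads the saved variable values into the session before the accuracy is evaluated.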

(3) saver.restore()

https://blog.csdn.net/qq_37285386/article/details/88957558

Reposted from blog.csdn.net/qq_36627158/article/details/108321673