Tensorflow笔记__使用mnist数据集并测试自己的手写图片

内容源于曹建老师的tensorflow笔记课程

源码链接:https://github.com/cj0012/AI-Practice-Tensorflow-Notes

测试图片下载:https://github.com/cj0012/AI-Practice-Tensorflow-Notes/blob/master/num.zip

主要包含四个文件:mnist_forward.py、mnist_backward.py、mnist_test.py 和 mnist_app.py

定义前向传播过程 mnist_forward.py:

import tensorflow as tf

# Network dimensions: 28*28 = 784 input pixels, 10 output classes,
# and a single 500-node hidden layer.
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER_NODE = 500

# 定义神经网络参数,传入两个参数,一个是shape一个是正则化参数大小
def get_weight(shape, regularizer):
    """Create a weight variable of the given shape.

    Args:
        shape: list of ints, the variable's shape.
        regularizer: L2 regularization coefficient, or None to disable
            regularization (e.g. at inference time).

    Returns:
        A tf.Variable initialized from a truncated normal (stddev=0.1).
    """
    # Truncated normal re-draws any sample farther than 2 stddevs from the mean.
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        # Add this variable's L2 penalty to the "losses" collection so the
        # training loss can sum it in.  (This statement had been merged into
        # the preceding comment in the original source, which both broke the
        # syntax and silently disabled regularization.)
        tf.add_to_collection("losses", tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

# 定义偏置b,传入shape参数
def get_bias(shape):
    """Return a bias variable of the given shape, initialized to zeros."""
    return tf.Variable(tf.zeros(shape))

# 定义前向传播过程,两个参数,一个是输入数据,一个是正则化参数
def forward(x, regularizer):
    """Build the two-layer inference graph: 784 -> 500 -> 10 logits.

    Args:
        x: input batch placeholder/tensor of shape [None, INPUT_NODE].
        regularizer: L2 coefficient forwarded to get_weight, or None.

    Returns:
        The raw logits tensor of shape [None, OUTPUT_NODE] (softmax is
        applied later, inside the loss function).
    """
    # Hidden layer: ReLU(x @ W1 + b1); bias size matches W1's output dim.
    w1 = get_weight([INPUT_NODE, LAYER_NODE], regularizer)
    b1 = get_bias(LAYER_NODE)
    hidden = tf.nn.relu(tf.matmul(x, w1) + b1)

    # Output layer: linear, no activation.
    w2 = get_weight([LAYER_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias(OUTPUT_NODE)
    return tf.matmul(hidden, w2) + b2

   

定义反向传播过程 mnist_backward.py:

#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

# Number of samples fed to the network per training step.
BATCH_SIZE = 200
# Initial learning rate, before exponential decay.
LEARNING_RATE_BASE = 0.1
# Decay factor applied to the learning rate once per epoch.
LEARNING_RATE_DECAY = 0.99
# L2 regularization coefficient.
REGULARIZER = 0.0001
# Total number of training steps.
STEPS = 50000
# Decay rate for the exponential moving average of trainable variables.
MOVING_AVERAGE_DECAY = 0.99
# Checkpoint directory and file-name prefix.
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"

def backward(mnist):
    """Train the network on MNIST and periodically checkpoint the model.

    Args:
        mnist: MNIST DataSet object from
            tensorflow.examples.tutorials.mnist.input_data
            (one-hot labels expected).
    """
    # Placeholders: each image is a flat vector of 784 pixels, each label
    # a one-hot vector of length 10.
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    # Forward pass with L2 regularization enabled during training.
    y = mnist_forward.forward(x, REGULARIZER)
    # Global step counter; incremented by the optimizer, never trained.
    global_step = tf.Variable(0, trainable=False)

    # Fused softmax + cross-entropy; one-hot labels are converted to class
    # indices with argmax, as the sparse variant requires.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    # Total loss = mean cross-entropy + accumulated L2 penalties.
    loss = cem + tf.add_n(tf.get_collection("losses"))

    # Exponentially decay the learning rate, stepping once per epoch.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Maintain shadow (moving-average) copies of all trainable variables;
    # the test/app scripts restore these to improve generalization.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # Group the optimizer step and the EMA update into a single train op.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name="train")

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(STEPS):
            # Draw a random batch of (image pixels, one-hot labels).
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})

            if i % 1000 == 0:
                # Fixed message typo: "Ater" -> "After".
                print("After {} training step(s),loss on training batch is {} ".format(step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=global_step)

def main():
    """Load the MNIST data set (one-hot labels) and start training."""
    dataset = input_data.read_data_sets("./data", one_hot=True)
    backward(dataset)


if __name__ == "__main__":
    main()

定义测试部分 mnist_test.py:

#coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

# Seconds to wait between successive evaluations of the newest checkpoint.
TEST_INTERVAL_SECS = 5

def test(mnist):
    """Repeatedly evaluate the newest checkpoint on the MNIST test set.

    Polls MODEL_SAVE_PATH every TEST_INTERVAL_SECS seconds, restores the
    moving-average weights, and prints test-set accuracy.

    Args:
        mnist: MNIST DataSet object (one-hot labels expected).
    """
    with tf.Graph().as_default() as g:
        # shape[0]=None allows any batch size; shape[1] is the feature count.
        x = tf.placeholder(tf.float32, shape=[None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, shape=[None, mnist_forward.OUTPUT_NODE])
        # No regularizer (and no dropout) at evaluation time.
        y = mnist_forward.forward(x, None)

        # Restore the moving-average (shadow) values instead of raw weights.
        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        # A prediction is correct when argmax of the logits matches the
        # one-hot label.  (This assignment had been swallowed into a comment
        # in the original source, leaving the name undefined.)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        # Accuracy = mean of the boolean results cast to float.
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        while True:
            with tf.Session() as sess:
                # Locate the latest checkpoint, if any.
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Recover the step count from the checkpoint file name
                    # (suffix after the last "-").
                    global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
                    accuracy_score = sess.run(
                        accuracy,
                        feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                    print("After {} training step(s),test accuracy is: {} ".format(global_step, accuracy_score))
                else:
                    # Fixed message typo ("chekpoint"); also removed the
                    # original debug print of sess.run(y, ...) here, which
                    # would raise FailedPreconditionError on uninitialized
                    # variables when no checkpoint was restored.
                    print("No checkpoint file found")
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():
    """Load the MNIST data set (one-hot labels) and start the eval loop."""
    dataset = input_data.read_data_sets("./data", one_hot=True)
    test(dataset)


if __name__ == "__main__":
    main()

定义使用手写图片部分mnist_app.py:

import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward

# 定义加载使用模型进行预测的函数
def restore_model(testPicArr):

    with tf.Graph().as_default() as tg:
        
        x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
        y = mnist_forward.forward(x,None)
        preValue = tf.argmax(y,1)
        # 加载滑动平均模型
        variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        with tf.Session() as sess:
            
            ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                # 恢复当前会话,ckpt中的值赋值给wb
                saver.restore(sess,ckpt.model_checkpoint_path)
                # 执行图计算
                preValue = sess.run(preValue,feed_dict={x:testPicArr})
                return preValue
            else:
                print("No checkpoint file found")
                return -1
# 图片预处理函数
def pre_pic(picName):
    # 先打开传入的原始图片
    img = Image.open(picName)
    # 使用消除锯齿的方法resize图片
    reIm = img.resize((28,28),Image.ANTIALIAS)
    # 变成灰度图,转换成矩阵
    im_arr = np.array(reIm.convert("L"))
    threshold = 50#对图像进行二值化处理,设置合理的阈值,可以过滤掉噪声,让他只有纯白色的点和纯黑色点
    for i in range(28):
        for j in range(28):
            im_arr[i][j] = 255-im_arr[i][j]
            if (im_arr[i][j]<threshold):
                im_arr[i][j] = 0
            else:
                im_arr[i][j] = 255
    # 将图像矩阵拉成1784列,并将值变成浮点型(像素要求的仕0-1的浮点型输入)
    nm_arr = im_arr.reshape([1,784])
    nm_arr = nm_arr.astype(np.float32)
    img_ready = np.multiply(nm_arr,1.0/255.0)

    return img_ready

def application():
    """Interactively classify user-supplied handwritten digit images."""
    # Ask how many images the user wants to test.
    count = int(input("input the number of test images:"))
    for _ in range(count):
        # Read one image path from the console, preprocess it, and predict.
        path = input("the path of test picture:")
        prepared = pre_pic(path)
        prediction = restore_model(prepared)
        print("The prediction number is :", prediction)

def main():
    """Entry point: run the interactive prediction loop."""
    application()


if __name__ == "__main__":
    main()

output:


The end.

猜你喜欢

转载自blog.csdn.net/li_haiyu/article/details/80846657
今日推荐