非常精简的Mnist分类,基于tensorflow框架

一、介绍

    基于tensorflow框架实现的Mnist数据分类。代码主要包括网络结构的搭建,训练超参数的导入和保存,损失函数地绘制等。不足之处是在网络结尾没用使用softmax函数,而直接使用了tanh输出了分类结果。下面请看代码的详细介绍

二、代码

  • 导入必要的包文件,需要的包我直接通过pycharm导入的,能导入的原因是采用了anaconda3底下的python.exe,新建工程的时候,从外部导入
# 需要使用到的包文件
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import argparse
import os
# 加上这一句能够使Plot绘制出来的图更精美
sns.set_style("whitegrid")
  • 训练参数设置,详细介绍请看代码注释,主要采用了argparse,该模块的好处是直接可以在运行时修改参数,比如:python main.py --data_dir= "**"
parser = argparse.ArgumentParser(description="Network for image classification")                
parser.add_argument('--data_dir', default='tem/data', help='Directory for training data')   # Mnist数据集存放位置
parser.add_argument('--result_dir', default='tem/result')                                   # 训练结果的存放
parser.add_argument('--model_dir', default='model/', help='the place of saving networks parameters')   #训练参数的存放地址
parser.add_argument('--batch_size', default=32)
parser.add_argument('--print_loss', default=10) # 每隔10次迭代打印损失值
parser.add_argument('--plot_loss', default=100) # 每隔100次迭代绘制损失函数曲线
parser.add_argument('--learning_rate', default=0.001, type=float) # 学习率,不易设置过大
parser.add_argument('--n_iterations', default=10000, type=int) # 迭代次数
args = parser.parse_args() # 将--*的*传递给arg,调用时直接使用args.data_dir这样的结构
  • 网络结构搭建
w_init = tf.random_normal_initializer(stddev=0.01)   # 权重w初始化,标准差为0.01,平均值0
def network(x): # 激活函数都为relu,除了输出
    layers1 = tf.layers.conv2d(x, 32, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init)      # 32个卷积核,3x3卷积核大小,步长为1,padding为'same',即输出大小为input/stride,向上取整
    layers2 = tf.layers.conv2d(layers1, 62, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init) 
    layers2_flatten = tf.contrib.layers.flatten(layers2)  # 将layers2的输出"磨平",降低相关维度,以供全连接层工作
    layers3 = tf.layers.dense(layers2_flatten, 200, activation=tf.nn.relu, kernel_initializer=w_init) # 200为全连接层单元个数,其它的痛卷积函数类似
    output = tf.layers.dense(layers3, 10, activation=tf.nn.tanh, kernel_initializer=w_init) # 使用tanh作为输出,比sigmoid好,因为sigmoid是有0项,不利于网络训练
    return output
  • 训练网络,详细介绍看注释
def training():
    input_x = tf.placeholder(tf.float32, [None, 28, 28, 1])  # 放置占位矩阵
    label_y = tf.placeholder(tf.float32, [None, 10])
    output_y = network(input_x)                              # 前向传播
    loss = tf.reduce_sum(tf.square(label_y-output_y))        # 计算同便签损失
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate).minimize(loss) # 使用Adam优化

    init_all_v = tf.global_variables_initializer()           # 张量初始化函数
    sess = tf.InteractiveSession()                           
    sess.run(init_all_v)                                     # 实行张量初始化
    saver = load_model(sess)                                 # 导入之前训练过的参数,如果没有则打印出来
    mnist = read_data_sets(args.data_dir, one_hot=True)      # 往指定目录读取Mnist数据集
    print('start training')                                   
    plot_loss = []                                           # 损失值缓存
    for i in range(args.n_iterations):                             
        batch_x, batch_y = mnist.train.next_batch(args.batch_size)      # 读取Batch_size
        batch_x = batch_x.reshape([args.batch_size, 28, 28, 1])         # 维度匹配   
        y = np.zeros([args.batch_size, 10])                             # 下面的操作是因为我读到的标签是6,8,9直接对应的图片的数字,所以将这些数字向量化,以便训练
        for j in range(args.batch_size):
            k = batch_y[j].astype(np.int)
            y[j, k] = 1.
        batch_y = y
        d_loss, _ = sess.run([loss, optimizer], feed_dict={input_x:batch_x, label_y:batch_y})  # 运行
        plot_loss.append(d_loss)

        if i % args.print_loss == 0 and i > 0:
            print('Iteration is : %d, Loss is: %f' % (i, d_loss))   # 打印损失
        if i % args.plot_loss == 0 and i > 0:            # 绘图
            plt.figure(figsize=(6*1.1618, 6))
            plt.plot(range(len(plot_loss)), plot_loss)
            plt.xlabel('iteration times')
            plt.ylabel('lost')
            plt.show()
        if i % 500 == 0 and i > 0:
            save_model(saver, sess, i)
  • 模块的导入与存储
def save_model(saver, sess, step):   # 存储模块
    saver.save(sess, os.path.join(args.model_dir, "classification"), global_step=step)
def load_model(sess):                # 导入模块 
    saver = tf.train.Saver()
    checkpoint = tf.train.get_checkpoint_state(args.model_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print("Successfully loaded:", checkpoint.model_checkpoint_path)
    else:
        print("Could not find any old weights!")
    return saver
  • 主函数
def main(_):
    training()
if __name__ == "__main__":
    tf.app.run()

从上往下黏贴就行,贴到IDE下就可以运行,还可以打印损失函数

鬼知道为什么下降这么快,,,

猜你喜欢

转载自blog.csdn.net/monotonomo/article/details/80675795