1. Introduction
今天是尝试用 TensorFlow 框架来跑 MNIST 手写数字数据集的第二天,主要学习训练网络。本 blog 主要记录一个学习的路径以及学习资料的汇总。
注意:这是用 Python 2.7 版本写的代码
第一天(LeNet 网络的搭建):https://blog.csdn.net/qq_36627158/article/details/108245969
第二天(训练网络):https://blog.csdn.net/qq_36627158/article/details/108315239
第三天(测试网络):https://blog.csdn.net/qq_36627158/article/details/108321673
第四天(单例测试):https://blog.csdn.net/qq_36627158/article/details/108397018
2. Code(mnist_train.py)
import tensorflow as tf
import matplotlib.pyplot as plt
import mnist_lenet
from tensorflow.examples.tutorials.mnist import input_data
# Hyper-parameters read by train_model():
#   batch_size - images fed per step (also fixes the placeholder shapes)
#   learn_rate - step size for the plain gradient-descent optimizer
#   iteration  - total number of training steps
batch_size, learn_rate, iteration = 64, 0.01, 1500
def train_model(train_dataset):
    """Build the training graph for the mnist_lenet model and run mini-batch SGD.

    Args:
        train_dataset: object exposing ``next_batch(batch_size)`` that returns
            ``(images, labels)``, e.g. ``mnist_data.train`` from
            ``input_data.read_data_sets(..., one_hot=True)``. Images are
            expected as NumPy arrays of shape (batch_size, 784) and labels as
            one-hot arrays of shape (batch_size, 10).

    Side effects:
        Creates the ``models/`` directory if needed, writes a checkpoint there
        every 50 iterations, prints the batch loss every 50 iterations, and
        shows a matplotlib plot of the recorded losses.
    """
    images_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[batch_size, 28, 28, 1]
    )
    labels_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[batch_size, 10]
    )

    # Forward pass; the result is used as logits below.
    # NOTE(review): assumes build_model_and_forward returns unscaled logits of
    # shape (batch_size, 10) - confirm against mnist_lenet.
    label_predict = mnist_lenet.build_model_and_forward(images_holder)

    # Per-image loss. The "sparse" variant expects integer class ids, so the
    # one-hot labels are converted back to ids with argmax.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=label_predict,
        labels=tf.argmax(labels_holder, axis=1)
    )
    # Mean loss over the batch.
    loss = tf.reduce_mean(cross_entropy)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learn_rate)
    train_update_op = optimizer.minimize(loss)
    saver = tf.train.Saver()

    # Saver.save() refuses to write when the parent directory is missing.
    tf.gfile.MakeDirs("models")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("Start training:")
        loss_plt = []
        for i in range(iteration):
            step = i + 1
            batch_images, batch_labels = train_dataset.next_batch(batch_size)
            # batch_images.shape == (batch_size, 784)
            # batch_labels.shape == (batch_size, 10)
            # Reshape with NumPy: calling tf.reshape here would add a new op
            # to the graph on every iteration (the graph keeps growing) and
            # would need .eval() to become a feedable value.
            batch_images_reshaped = batch_images.reshape(batch_size, 28, 28, 1)
            # The train op's fetched value is None; only the loss is kept.
            loss_value, _ = sess.run(
                [loss, train_update_op],
                feed_dict={
                    images_holder: batch_images_reshaped,
                    labels_holder: batch_labels
                }
            )
            if step % 50 == 0:
                # One pre-formatted string prints the same text under
                # Python 2 and Python 3.
                print("After %d iteration, loss on training batch is %s"
                      % (step, loss_value))
                loss_plt.append(loss_value)
                # Checkpoint only at the reporting interval instead of every step.
                saver.save(sess, "models/model.ckpt", global_step=step)
        print("End training")

    plt.plot(loss_plt, color=(0, 0, 0), label='loss')
    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Obtain MNIST under MNIST_data/; one_hot=True yields 10-element label
    # vectors, which train_model turns back into class ids with argmax.
    mnist_data = input_data.read_data_sets('MNIST_data/', one_hot=True)
    # read_data_sets either returns a _Datasets object or raises, so this
    # check is purely defensive.
    if mnist_data is not None:
        print("Load data completely!")
        train_model(mnist_data.train)
3. Materials
1、tensorflow 官方文档
https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/
2、input_data.py
4. Code Details
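1、注意,TensorFlow 2.0 版本没有 tensorflow.examples.tutorials 模块
解决方案:使用 TensorFlow 1.x(例如与上面官方文档一致的 1.15 版本)运行本代码——本代码用到的 tf.placeholder、tf.Session 等接口本身也是 TensorFlow 1.x 的 API;也可以尝试把下面第 2 条链接中 input_data.py 的源码下载到项目目录中,改为直接 import input_data,再调用 input_data.read_data_sets()。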
2、input_data 中 read_data_sets()
- https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/input_data.py
- https://blog.csdn.net/KID_yuan/article/details/89040245
one_hot 编码参数的作用:https://blog.csdn.net/weiyumeizi/article/details/81502471
截取了其中的 read_data_sets() 函数的源码
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
    """Return the MNIST train / validation / test splits as a ``_Datasets``.

    Excerpt of ``read_data_sets()`` from TensorFlow's
    ``tensorflow/examples/tutorials/mnist/input_data.py``; the helpers used
    here (``_DataSet``, ``_Datasets``, ``_maybe_download``,
    ``_extract_images``, ``_extract_labels``, ``gfile``) come from that module.

    Args:
        train_dir: directory passed to ``_maybe_download`` for each of the
            four MNIST ``.gz`` archives.
        fake_data: if true, touch no files and return three empty
            ``_DataSet`` objects created with ``fake_data=True``.
        one_hot: forwarded to ``_extract_labels`` (and to the fake datasets).
        dtype, reshape, seed: forwarded to every ``_DataSet``.
        validation_size: number of leading training examples moved into the
            validation split.
        source_url: base URL prepended to each archive name; a falsy value
            (e.g. '' or None) falls back to ``DEFAULT_SOURCE_URL``.

    Raises:
        ValueError: if ``validation_size`` is outside
            ``[0, number of training images]``.
    """
    if fake_data:
        # Three separate empty datasets, created in train/validation/test order.
        train, validation, test = [
            _DataSet([], [], fake_data=True, one_hot=one_hot,
                     dtype=dtype, seed=seed)
            for _ in range(3)
        ]
        return _Datasets(train=train, validation=validation, test=test)

    base_url = source_url or DEFAULT_SOURCE_URL

    def _load(archive, extractor, **extra):
        # Resolve one archive under train_dir and parse it with `extractor`.
        path = _maybe_download(archive, train_dir, base_url + archive)
        with gfile.Open(path, 'rb') as stream:
            return extractor(stream, **extra)

    train_images = _load('train-images-idx3-ubyte.gz', _extract_images)
    train_labels = _load('train-labels-idx1-ubyte.gz', _extract_labels,
                         one_hot=one_hot)
    test_images = _load('t10k-images-idx3-ubyte.gz', _extract_images)
    test_labels = _load('t10k-labels-idx1-ubyte.gz', _extract_labels,
                        one_hot=one_hot)

    n_train = len(train_images)
    if not 0 <= validation_size <= n_train:
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'.format(
                n_train, validation_size))

    # The first `validation_size` training examples become the validation set.
    cut = validation_size
    validation_images, train_images = train_images[:cut], train_images[cut:]
    validation_labels, train_labels = train_labels[:cut], train_labels[cut:]

    shared = {'dtype': dtype, 'reshape': reshape, 'seed': seed}
    return _Datasets(
        train=_DataSet(train_images, train_labels, **shared),
        validation=_DataSet(validation_images, validation_labels, **shared),
        test=_DataSet(test_images, test_labels, **shared),
    )
3、tf.placeholder()
https://blog.csdn.net/kdongyi/article/details/82343712
4、tf.nn.sparse_softmax_cross_entropy_with_logits()
https://blog.csdn.net/ZJRN1027/article/details/80199248
5、tf.argmax()
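返回沿指定维度(axis)上最大的那个数值所在的下标。本代码中用 tf.argmax(labels_holder, axis=1) 把 one-hot 标签转换成类别编号,因为 sparse_softmax_cross_entropy_with_logits() 的 labels 参数需要的是类别编号,而不是 one-hot 向量。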
https://blog.csdn.net/qq575379110/article/details/70538051/
6、tf.reduce_mean()
https://blog.csdn.net/dcrmg/article/details/79797826
7、tf.train.GradientDescentOptimizer()
https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/GradientDescentOptimizer
https://www.cnblogs.com/smallredness/p/11203250.html
8、tf.train.Saver()
https://blog.csdn.net/yz19930510/article/details/80324389
9、tf.global_variables_initializer()
- https://blog.csdn.net/qq_37285386/article/details/89054090
- https://blog.csdn.net/qq_26591517/article/details/80601225?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param
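10、mnist.train.next_batch 到底完成了什么工作?
next_batch(batch_size) 每次从训练集中取出接下来的 batch_size 个样本,返回 (batch_images, batch_labels)。本代码中 batch_images 的形状为 (batch_size, 784),batch_labels 的形状为 (batch_size, 10)(因为 read_data_sets() 使用了 one_hot=True)。一轮(epoch)数据取完后,它会打乱数据,再从头开始取。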
11、这段代码中的 “_” 是什么意思?
# Fetch the batch loss and run one optimizer step in a single sess.run call;
# the train op's fetched value (None) is discarded into "_".
loss_value, _ = sess.run(
    [loss, train_update_op],
    feed_dict={
        images_holder: batch_images_reshaped.eval(),
        labels_holder: batch_labels
    }
)
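我试着输出了一下,就是 None。“_” 是 Python 中约定俗成的占位变量名,表示这个返回值不需要使用。sess.run() 对每个 fetch 分别返回结果:loss 返回具体的损失值;而 train_update_op 是 optimizer.minimize() 返回的 Operation,运行它只会产生更新参数的副作用,返回值为 None。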
12、这段代码中的 batch_images_reshaped 为什么要使用 eval() 函数?
# batch_images_reshaped is the tf.Tensor returned by tf.reshape(), so it has to
# be evaluated to a concrete NumPy array with .eval() before it can be fed.
loss_value, _ = sess.run(
    [loss, train_update_op],
    feed_dict={
        images_holder: batch_images_reshaped.eval(),
        labels_holder: batch_labels
    }
)
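如果不用,TensorFlow 会报错:TypeError: The value of a feed cannot be a tf.Tensor object。原因是 feed_dict 中的值必须是 Python 数值、列表或 NumPy 数组等具体数据,而 tf.reshape() 返回的是计算图中的一个 tf.Tensor;eval() 会在当前默认会话中计算这个张量,得到 NumPy 数组。更好的做法是直接用 NumPy 变形:batch_images.reshape(batch_size, 28, 28, 1)。这样既不需要 eval(),也不会在每次迭代时向计算图中添加新的 reshape 节点。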