机器学习 TensorFlow2.0 官方入门程序解说

本文介绍 机器学习 TensorFlow2.0 官方入门程序。我觉得这个入门程序不容易看懂,所以做一个解说。其实更简单易懂的入门程序是我的 机器学习 tensorflow 2 的hello world

机器学习 TensorFlow2.0 官方入门程序在 https://tensorflow.google.cn/overview/ ,有个中文翻译解释了的在 TensorFlow 2.0 极简教程,不到 20 行代码带你入门

下面我就把这个很短的代码列出了,其中也包含了我的注释。应该可以基本看懂,主要是数据库获取的准备数据,不清楚。后面我将解剖这些数据,现在就把大程序看懂吧。

import tensorflow as tf
mnist = tf.keras.datasets.mnist
#从数据库获取,培训数据,测试数据,并归1 化处理
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
#建立模型 设置网络层
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型 优化器 损失函数 指标
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# 进行5次培训,然后做测试
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

1:程序包含tensorflow程序包,然后导入数据库给你准备的数据,包含培训数据: x_train, y_train ,以及测试数据。

2:归1化处理,就是输入数据范围在0,1之间。

3:建立模型,设置网络层

4:编译模型,提供了优化器 损失函数 指标

5:导入数据,进行5次培训

6:用测试数据,进行测试,查看精度情况。、

程序运行结果是:

我们看到培训样本是60000个,然后5次培训过程,精度是97.63%

测试数据的结果也差不多,97.57%

程序简短,也容易理解。但对她的培训,测试数据是什么,一点也不清楚。

下面我们就剖析她的数据内容。

我把模型和培训代码,注释掉,其实也可以保留,只是比较慢一点,最后代码如下:

import tensorflow as tf
mnist = tf.keras.datasets.mnist
#从数据库获取,培训数据,测试数据,并归1 化处理
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
#建立模型 设置网络层
"""
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型 优化器 损失函数 指标
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# 进行5次培训,然后做测试
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)
"""
#显示数据结构
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
#显示结果数据内容
print(y_train)
print(y_test)


import matplotlib.pyplot as plt


# 这段代码就是显示x_train[0]单元的内容,图示化显示,你可以改0为0-59999之间的数据
#  对程序没影响,和上面的print 差不多
plt.figure()
plt.imshow(x_train[0])
plt.colorbar()
plt.grid(False)
plt.show()

# 这段代码也是显示数据内容的作用
# 显示5行5列 训练数据的内容x_train[i],标签的内容y_train[i]
plt.figure(figsize=(10,10))
# 25 个数据组
for i in range(25):
    #图是5行5列
    plt.subplot(5,5,i+1)    
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i], cmap=plt.cm.binary)
    plt.xlabel(y_train[i])
plt.show()

代码里包含了import matplotlib.pyplot as plt   

matplotlib是个画图程序,也不必了解太多,想细看可网上查。

我们看到x_train 是60000个样本,(每个样本28X28个数据), y_train 60000样本(每个样本只有一个数据)

x_test和 y_test 是10000个测试样本。

y_train, y_test 是0-9之间的数字。

再抽取1个 x_train 我们看到数据是0-255 间的数据,/255 就归1化了。

抽取25 个x_train 看图,原来是0-9 的草体,或手写体,下面有0-9的标签(y_train)。

我们训练的内容是手写体的0-9 数字图片,判断是哪个数字。

对比下 机器学习 TensorFlow2.0 教程-图像分类, 图像分类只是10种商品,除了数据库不一样外,其实内容都一样,但那里包含了数据解析,就是我这里后面分析的内容。

本文介绍到此。

发布了131 篇原创文章 · 获赞 112 · 访问量 19万+

猜你喜欢

转载自blog.csdn.net/leon_zeng0/article/details/102775172