MNIST手写数字识别—— 02 数据集
MNIST 数据集
MNIST图像数据集使用形如[28,28]的二阶数组来表示每张图像,数组中的每个元素对应一个像素点。该数据集中的图像都是256阶灰度图,像素值0表示白色(背景),255表示黑色(前景)。由于每张图像的尺寸都是28x28像素,为了方便连续存储,我们可以将形如[28,28]的二阶数组“摊平”成形如[784]的一阶数组。数组中的784个元素共同组成了一个784维的向量。
More info: http://yann.lecun.com/exdb/mnist/
使用 tf.Keras 加载 MNIST 数据集
tf.keras.datasets.mnist.load_data(path=‘mnist.npz’)
Arguments:
- path: 本地缓存 MNIST 数据集(mnist.npz)的相对路径(~/.keras/datasets)
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
详情查看 mnist.load_data API 文档
# tensorflow2.0的数据集集成到keras高级接口之中
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(15):
plt.subplot(3, 5, i+1) # 绘制前15个手写体数字,以3行5列子图形式展示
plt.tight_layout() # 自动适配子图尺寸,看起来不拥挤
plt.imshow(x_train[i], cmap='Greys') # 使用灰色显示像素灰度值
plt.title("Label: {}".format(y_train[i])) # 设置标签为子图标题
plt.xticks([]) # 删除x轴标记 否则会自动标上坐标
plt.yticks([]) # 删除y轴标记
x_train[0]