初识神经网络(helloworld)
要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。我们将使用MNIST数据集。
在机器学习中,分类问题中的某个类别叫作类(class),数据点叫作样本(sample),与某个样本对应的类叫作标签(label)。
MNIST数据集已预先加载在Keras库中,其中包含4个NumPy数组,如代码清单2-1所示。
代码清单2-1 加载Keras中的MNIST数据集
from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。
图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。我们来看一下训练数据:
print(train_images.shape)
print(len(train_labels))
print(train_labels)
再来看一下测试数据:
print(test_images.shape)
print(len(test_labels))
print(test_labels)
工作流程如下:首先,将训练数据(train_images和train_labels)输入神经网络;然后,神经网络学习将图像和标签关联在一起;最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。
构建神经网络,如代码清单2-2所示
代码清单2-2 神经网络架构
#顺序模型,通过传入层的列表构造
model = keras.Sequential([
#全密集链接层 512个单元,激活函数为relu
layers.Dense(512, activation="relu"),
#全密集链接层 10个单元,激活函数为softmax
layers.Dense(10, activation="softmax")
])
神经