一、MNIST
1.1 要求
需要实现的模型:1. 基于 Backward Propagation 算法的人工神经网络;2. 卷积神经网络。
在 MNIST 手写字符识别数据集(http://yann.lecun.com/exdb/mnist/)上对实现的两个模型进行实验测试,陈述其原理与结果。
1.2 环境
- Windows 10
- Conda 4.10.3
- Python 3.9
- PyTorch 1.9.1
- PyTorch-Lightning 1.4.9
- CUDA
- DataSpell 2021.3 EAP
1.3 准备
准备阶段,使用 PyTorch-Lightning 进行训练框架的搭建。
在{model_name}_main.py 入口脚本(例如 cnn_main.py)中设置 Global Seed 为 42,使用自定义的 MnistDataLoader 作为训练数据,使用 pl.Trainer()对模型进行训练,自定义是否使用 GPU、epoch 数等参数:
在{model_name}_model.py 模型脚本(例如 cnn_model.py)中定义 MnistDataLoader,用于下载 MNIST 数据集,并将其进行向量化、正则化,划分 9:1 的训练-验证集,并指定 batch_size 参数:
1.4 人工神经网络
1.4.1 神经网络结构
本实验中实现的人工神经网络由两个全连接层组成,拥有两个权重矩阵(weight matrix)与两个偏置向量(bias vector),每个全连接层后使用 Sigmoid 作为激活函数:
本实验对两个权重矩阵进行 Xavier 初始化:
在前向传播函数中,构建 tzq_pipeline 函数进行矩阵数据的运算,并返回预测结果 y_hat 与中间层结果 b,其中,当计算出 b 与 y_hat 时,对两者使用 Sigmoid 函数进行激活:
在每次训练步(training_step)中,对一个 batch 的数据进行预测,使用均方误差进行 loss 计算,并对网络参数进行反向传播(backward propagation):
在反向传播过程中,模型对均方误差进行求导,沿着梯度下降的方向进行权重更新,以达到凸函数极点;模型根据《机器学习(周志华)》5.3 误差逆传播算法中的权重更新方程对权重进行更新:
1.4.2 实验设置
模型使用均方误差(Mean Square Error)作为损失函数(Loss Function)。
模型使用精度(Accuracy)作为测试标准。
模型设置学习率为 1e-2。
模型使用 Sigmoid 函数作为激活函数:
模型设置 epoch 次数为 64。
模型设置 batch_size 为 16。
1.4.3 实验结果
在经过充分的训练后,使用测试数据集对模型进行测试,最终得到 68% 的预测正确率,达到了较高的预测精度。
1.5 卷积神经网络
1.5.1 神经网络结构
本实验设计的 CNN 网络由五层网络模块组成,其中第一、二个模块运算卷积操作,第三、四个模块的带有 ReLU 操作的线性层,第五个模块是输出结果的线性层:
模型使用 5×5 的卷积核对图像进行卷积,在卷积之后使用 ReLU 作为激活函数(Activation Function) ,并使用 size 为 2×2、stride 为 2 的池化核对进行最大池化操作。
1.5.2 实验设置
模型使用交叉熵(Cross Entropy)作为损失函数(Loss Function)。
模型使用精度(Accuracy)作为测试标准。
模型设置学习率为 1e-2。
模型使用 ReLU 函数作为激活函数,因为相对于 Sigmoid 与 Tanh,ReLU 函数具有以下优势:1. 在误差进行反向传播时,可以缓解梯度消失;2. ReLU 会使一部分神经元的输出为 0,造成了网络的稀疏性,并且减少了参数的相互依存关系,缓解了过拟合问题的发生;3. 相对于 Sigmoid 与 Tanh,ReLU 函数求导简单,计算量较小,节约本就不多的计算资源。
模型设置 epoch 次数为 64。
模型设置 batch_size 为 16。
模型使用 GPU 进行训练加速。
1.5.3 实验结果
在经过充分的训练后,使用测试数据集对模型进行测试,最终得到 97% 的预测正确率,达到了极高的预测精度。
代码文档
实验的代码在提交的压缩包中的“代码”文件夹下的“ANNwithPyTorch”文件夹下,其中:
data 文件夹中包含了预下载的 MNIST 数据集;
ann_model.py 中包含了 ANN 模型代码与 MNIST 数据集处理代码,ann_main.py 是 ANN 实验的入口脚本;
cnn_model.py 中包含了 CNN 模型代码与 MNIST 数据集处理代码,cnn_main.py 是 CNN 实验的入口脚本。