基于PyTorch与PyTorch-Lightning进行Backward Propagation人工神经网络模型与CNN模型的构建

一、MNIST

1.1 要求

需要实现的模型:1. 基于 Backward Propagation 算法的人工神经网络;2. 卷积神经网络。

在 MNIST 手写字符识别数据集(http://yann.lecun.com/exdb/mnist/)上对实现的两个模型进行实验测试,陈述其原理与结果。

1.2 环境

  • Windows 10
  • Conda 4.10.3
  • Python 3.9
  • PyTorch 1.9.1
  • PyTorch-Lightning 1.4.9
  • CUDA
  • DataSpell 2021.3 EAP

1.3 准备

准备阶段,使用 PyTorch-Lightning 进行训练框架的搭建。

在{model_name}_main.py 入口脚本(例如 cnn_main.py)中设置 Global Seed 为 42,使用自定义的 MnistDataLoader 作为训练数据,使用 pl.Trainer()对模型进行训练,自定义是否使用 GPU、epoch 数等参数:

在{model_name}_model.py 模型脚本(例如 cnn_model.py)中定义 MnistDataLoader,用于下载 MNIST 数据集,并将其进行向量化、正则化,划分 9:1 的训练-验证集,并指定 batch_size 参数:

1.4 人工神经网络

1.4.1 神经网络结构

本实验中实现的人工神经网络由两个全连接层组成,拥有两个权重矩阵(weight matrix)与两个偏置向量(bias vector),每个全连接层后使用 Sigmoid 作为激活函数:

本实验对两个权重矩阵进行 Xavier 初始化:

在前向传播函数中,构建 tzq_pipeline 函数进行矩阵数据的运算,并返回预测结果 y_hat 与中间层结果 b,其中,当计算出 b 与 y_hat 时,对两者使用 Sigmoid 函数进行激活:

在每次训练步(training_step)中,对一个 batch 的数据进行预测,使用均方误差进行 loss 计算,并对网络参数进行反向传播(backward propagation):

在反向传播过程中,模型对均方误差进行求导,沿着梯度下降的方向进行权重更新,以达到凸函数极点;模型根据《机器学习(周志华)》5.3 误差逆传播算法中的权重更新方程对权重进行更新:

1.4.2 实验设置

模型使用均方误差(Mean Square Error)作为损失函数(Loss Function)。

模型使用精度(Accuracy)作为测试标准。

模型设置学习率为 1e-2。

模型使用 Sigmoid 函数作为激活函数:

模型设置 epoch 次数为 64。

模型设置 batch_size 为 16。

1.4.3 实验结果

在经过充分的训练后,使用测试数据集对模型进行测试,最终得到 68% 的预测正确率,达到了较高的预测精度。

1.5 卷积神经网络

1.5.1 神经网络结构

本实验设计的 CNN 网络由五层网络模块组成,其中第一、二个模块运算卷积操作,第三、四个模块的带有 ReLU 操作的线性层,第五个模块是输出结果的线性层:

模型使用 5×5 的卷积核对图像进行卷积,在卷积之后使用 ReLU 作为激活函数(Activation Function) ,并使用 size 为 2×2、stride 为 2 的池化核对进行最大池化操作。

1.5.2 实验设置

模型使用交叉熵(Cross Entropy)作为损失函数(Loss Function)。

模型使用精度(Accuracy)作为测试标准。

模型设置学习率为 1e-2。

模型使用 ReLU 函数作为激活函数,因为相对于 Sigmoid 与 Tanh,ReLU 函数具有以下优势:1. 在误差进行反向传播时,可以缓解梯度消失;2. ReLU 会使一部分神经元的输出为 0,造成了网络的稀疏性,并且减少了参数的相互依存关系,缓解了过拟合问题的发生;3. 相对于 Sigmoid 与 Tanh,ReLU 函数求导简单,计算量较小,节约本就不多的计算资源。

模型设置 epoch 次数为 64。

模型设置 batch_size 为 16。

模型使用 GPU 进行训练加速。

1.5.3 实验结果

在经过充分的训练后,使用测试数据集对模型进行测试,最终得到 97% 的预测正确率,达到了极高的预测精度。

代码文档

实验的代码在提交的压缩包中的“代码”文件夹下的“ANNwithPyTorch”文件夹下,其中:

data 文件夹中包含了预下载的 MNIST 数据集;

ann_model.py 中包含了 ANN 模型代码与 MNIST 数据集处理代码,ann_main.py 是 ANN 实验的入口脚本;

cnn_model.py 中包含了 CNN 模型代码与 MNIST 数据集处理代码,cnn_main.py 是 CNN 实验的入口脚本。

猜你喜欢

转载自blog.csdn.net/newlw/article/details/125008249