一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第13天,点击查看活动详情。
构建深度神经网络提高模型准确性
我们在之前模型中使用的神经网络在输入层和输出层之间只有一个隐藏层。在本节中,我们将学习在神经网络中使用多个隐藏层(因此称为深度神经网络),以探究网络深度对模型性能的影响。
深度神经网络意味着在输入层和输出层间存在多个隐藏层。多个隐藏层确保神经网络可以学习输入和输出之间的复杂非线性关系,而简单的神经网络则无法完成这样的需求。经典深度神经网络架构如下所示:
通过在输入和输出层之间添加多个隐藏层来构建深度神经网络架构,步骤如下所示。
- 加载数据集并对数据集进行缩放:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
num_pixels = x_train.shape[1] * x_train.shape[2]
x_train = x_train.reshape(-1, num_pixels).astype('float32')
x_test = x_test.reshape(-1, num_pixels).astype('float32')
x_train = x_train / 255.
x_test = x_test / 255.
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]
复制代码
- 在输入和输出层之间使用多个隐藏层构建模型:
model = Sequential()
model.add(Dense(512, input_dim=num_pixels, activation='relu'))
model.add(Dense(1024, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
复制代码
模型体系结构的相关模型信息,如下所示:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 512) 401920
_________________________________________________________________
dense_1 (Dense) (None, 1024) 525312
_________________________________________________________________
dense_2 (Dense) (None, 64) 65600
_________________________________________________________________
dense_3 (Dense) (None, 10) 650
=================================================================
Total params: 993,482
Trainable params: 993,482
Non-trainable params: 0
_________________________________________________________________
复制代码
由于深度神经网络架构中包含更多的隐藏层,因此模型中也包含更多的参数。
- 建立了模型之后,就可以编译并拟合模型:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
history = model.fit(x_train, y_train,
validation_data=(x_test, y_test),
epochs=50,
batch_size=64,
verbose=1)
复制代码
训练完成的模型的准确度约为 98.9%
,比之前所用模型架构所得到的精确度略好,这是由于 MNIST
数据集相对简单。训练和测试损失及准确率如下:
在上图中可以看到,训练数据集准确率在很大程度上优于测试数据集准确率,这表明深度神经网络对训练数据进行了过度拟合。在之后的学习中,我们将了解避免训练数据过拟合的方法。