tensorflow.keras入门4

tensorflow.keras入门4-过拟合和欠拟合

简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。

过拟合和欠拟合

以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:
baseline模型结构为:10000-16-16-1
smaller_model模型结构为:10000-4-4-1
bigger_model模型结构为:10000-512-512-1
造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。
过拟合和欠拟合

解决方法–正则化

正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。tf.keras通常在损失函数后添加正则项,L1正则化和L2正则化。

l2_model = keras.models.Sequential([
	#权重l2正则化
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation=tf.nn.relu, input_shape=(10000,)),
    #权重l2正则化
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation=tf.nn.relu),
    keras.layers.Dense(1, activation=tf.nn.sigmoid)
])

l2_model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['accuracy', 'binary_crossentropy'])
l2_model_history = l2_model.fit(train_data, train_labels,
                                epochs=20,
                                batch_size=512,
                                validation_data=(test_data, test_labels),
                                verbose=2)

dropout

Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。

dpt_model = keras.models.Sequential([
    keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
    keras.layers.Dropout(0.3), #百分之30的神经元失效
    keras.layers.Dense(16, activation=tf.nn.relu),
    keras.layers.Dropout(0.7), #百分之70的神经元失效
    keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
dpt_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy','binary_crossentropy'])
 
dpt_model_history = dpt_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)

总结

常用防止过拟合的方法有:
1.增加数据量
2.减少网络结构参数
3.正则化
4.dropout
5.数据扩增data-augmentation
6.批标准化

猜你喜欢

转载自blog.csdn.net/qq_35297368/article/details/112687375