Fully Connected Layers -- Basic Text Classification

Hands-on text classification project: movie reviews (TensorFlow 2.0)

In this article we classify movie reviews as "positive" or "negative" based on their text. This is an example of binary (two-class) classification, an important and widely applicable kind of machine learning problem.

We will use the IMDB dataset, which contains the text of 50,000 movie reviews from the Internet Movie Database, split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews.

This chapter uses tf.keras, a high-level API for building and training models in TensorFlow. For a more advanced text classification tutorial using tf.keras, see the MLCC Text Classification Guide.

import tensorflow as tf
from tensorflow import keras
import numpy as np

1. Download the IMDB dataset

IMDB sentiment classification dataset.

  • 1. Module: tf.keras.datasets.imdb
  • 2. Functions:
    • get_word_index(…): Retrieves the dictionary mapping word indices back to words.
    • load_data(…): Loads the IMDB dataset.

The IMDB dataset ships with TensorFlow. It has already been preprocessed so that the reviews (sequences of words) have been converted to sequences of integers, where each integer represents a specific word in a dictionary.

The following code downloads the IMDB dataset to your machine (or uses a cached copy if you have already downloaded it):

imdb = keras.datasets.imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

The argument num_words=10000 keeps the 10,000 most frequently occurring words in the training data; rarer, low-frequency words are discarded to keep the size of the data manageable.
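
As a quick sanity check (a sketch, not part of the original tutorial), you can verify that no word index in the loaded sequences reaches the num_words cutoff, since rarer words are replaced by the out-of-vocabulary index:

# With num_words=10000, every index should be at most 9999
print(max(max(sequence) for sequence in train_data))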

2. Explore the data

Let's take a moment to understand the format of the data. The dataset comes preprocessed: each example is an array of integers representing the words of a movie review, and each label is an integer value of 0 or 1, where 0 is a negative review and 1 is a positive review.

print("Training entries(条目): {}, labels: {}".format(len(train_data), len(train_labels)))

Training entries: 25000, labels: 25000

The review text has been converted to integers, where each integer represents a specific word in a dictionary. Here is what the first review looks like, together with its label:

print(train_data[0], train_labels[0])

[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32] 1

Movie reviews may have different lengths. The following code shows the number of words in the first and second reviews. Since inputs to a neural network must be the same length, we will need to resolve this later.

len(train_data[0]), len(train_data[1])



(218, 189)

2.1 Converting the integers back to text

It may be useful to know how to convert the integers back to text.

Here, we create a helper function to query a dictionary object that contains the integer-to-string mapping:

# get_word_index() returns a dictionary mapping words to integer indices
# (a lookup table where each word corresponds to one integer, which makes queries easy)
word_index = imdb.get_word_index()


# The first indices are reserved
# Keys stay the same, every value is shifted by 3, and four special tokens are added
word_index = {k: (v + 3) for k, v in word_index.items()}
word_index["<PAD>"] = 0     # used to pad every sentence to the same length (PAD: padding)
word_index["<START>"] = 1
word_index["<UNK>"] = 2     # unknown: rare words or names outside the vocabulary
word_index["<UNUSED>"] = 3

# Swap the keys and values of the dictionary
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])

# Decode a review back into the original sentence
def decode_review(text):
    return ' '.join([reverse_word_index.get(i, '?') for i in text])

A dictionary's items() method returns an iterable of (key, value) tuples.
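
For example (a toy dictionary, not the real word_index), items() together with a comprehension inverts a mapping in one line:

d = {'good': 7, 'bad': 9}
print(list(d.items()))               # [('good', 7), ('bad', 9)]
print({v: k for k, v in d.items()})  # {7: 'good', 9: 'bad'}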

Now we can use the decode_review function to display the text of the first review:

decode_review(train_data[0])



"<START> this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert <UNK> is an amazing actor and now the same being director <UNK> father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for <UNK> and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all"

3. Prepare the data

The reviews (arrays of integers) must be converted to tensors before they can be fed into the neural network. This conversion can be done in either of two ways:

  • One-hot encode the arrays, converting them into vectors of 0s and 1s. For example, the sequence [3, 5] (two words, one represented by 3 and one by 5) becomes a 10,000-dimensional vector that is all zeros except for indices 3 and 5, which are ones. This vector could then feed the first layer of the network, a Dense layer that can handle floating-point vector data. This approach is memory intensive, however, requiring a num_words * num_reviews matrix, and it only records whether a word occurs, not its frequency (a small sketch of this encoding follows below).
  • Alternatively, we can pad the arrays (extending the integer arrays) so they all have the same length, giving every sequence a common max_length (256), and then create an integer tensor of shape max_length * num_reviews. We can use an embedding layer capable of handling this shape as the first layer of the network.

Here, we use the second approach.
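
For reference, here is a minimal sketch (not used in this tutorial, written only to illustrate the first option) of a multi-hot encoding function:

import numpy as np

def multi_hot_encode(sequences, dimension=10000):
    # Build a (num_reviews, dimension) matrix of 0s and 1s:
    # results[i, j] is 1 if word index j appears in review i (word frequency is ignored)
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.0
    return results

# Example: encoded_train = multi_hot_encode(train_data)  # shape (25000, 10000), memory-hungry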

Since the movie reviews must all be the same length, we use the pad_sequences function to standardize the lengths (padding is appended after each sequence):

train_data = keras.preprocessing.sequence.pad_sequences(train_data,
                                                        value=word_index["<PAD>"],
                                                        padding='post',  # pad at the end of each sequence
                                                        maxlen=256)

test_data = keras.preprocessing.sequence.pad_sequences(test_data,
                                                       value=word_index["<PAD>"],
                                                       padding='post',
                                                       maxlen=256)

Now let's look at the length of the samples:

len(train_data[0]), len(train_data[1])



(256, 256)

And inspect the (now padded) first review:

train_data[0]



array([   1,   14,   22,   16,   43,  530,  973, 1622, 1385,   65,  458,
       4468,   66, 3941,    4,  173,   36,  256,    5,   25,  100,   43,
        838,  112,   50,  670,    2,    9,   35,  480,  284,    5,  150,
          4,  172,  112,  167,    2,  336,  385,   39,    4,  172, 4536,
       1111,   17,  546,   38,   13,  447,    4,  192,   50,   16,    6,
        147, 2025,   19,   14,   22,    4, 1920, 4613,  469,    4,   22,
         71,   87,   12,   16,   43,  530,   38,   76,   15,   13, 1247,
          4,   22,   17,  515,   17,   12,   16,  626,   18,    2,    5,
         62,  386,   12,    8,  316,    8,  106,    5,    4, 2223, 5244,
         16,  480,   66, 3785,   33,    4,  130,   12,   16,   38,  619,
          5,   25,  124,   51,   36,  135,   48,   25, 1415,   33,    6,
         22,   12,  215,   28,   77,   52,    5,   14,  407,   16,   82,
          2,    8,    4,  107,  117, 5952,   15,  256,    4,    2,    7,
       3766,    5,  723,   36,   71,   43,  530,  476,   26,  400,  317,
         46,    7,    4,    2, 1029,   13,  104,   88,    4,  381,   15,
        297,   98,   32, 2071,   56,   26,  141,    6,  194, 7486,   18,
          4,  226,   22,   21,  134,  476,   26,  480,    5,  144,   30,
       5535,   18,   51,   36,   28,  224,   92,   25,  104,    4,  226,
         65,   16,   38, 1334,   88,   12,   16,  283,    5,   16, 4472,
        113,  103,   32,   15,   16, 5345,   19,  178,   32,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0])

4. Build the model

A neural network is built by stacking layers. This requires two main architectural decisions:

  • How many layers to use in the model?
  • How many hidden units (neurons) to use for each layer?

In this example, the input data consists of arrays of word indices, and the labels to predict are 0 or 1. Let's build a model for this problem:

"""
输入数据是单词组合,标签是0或者1
先进行数据稀疏稠密化,因为sequence里面的word_index值是[0~10000]内稀疏的,所以将每一个单词用一个16维的向量代替
input(1024,256) => output(1024,256,16)
再通过均值的池化层,将每一个sequence做均值,类似于将单词合并 ;input(1024,256,16),output(1024,16)
全连接层采用relu激活函数;input(1024,16) => output(1024,16)
全连接层采用sigmoid激活函数;input(1024,16) => output(1024,1)
"""
# The input shape is the vocabulary size used for the movie reviews (10,000 words);
# anything outside it is treated as a rare word or a name
vocab_size = 10000

model = keras.Sequential()  # a container in which layers are stacked linearly
model.add(keras.layers.Embedding(vocab_size, 16))  # 160,000 parameters: a 16-dimensional vector for each of the 10,000 words
model.add(keras.layers.GlobalAveragePooling1D())
model.add(keras.layers.Dense(16, activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))

model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_1 (Embedding)      (None, None, 16)          160000    
_________________________________________________________________
global_average_pooling1d_1 ( (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 17        
=================================================================
Total params: 160,289
Trainable params: 160,289
Non-trainable params: 0
_________________________________________________________________


The layers are stacked sequentially to build the classifier:

  • 1. The first layer is an Embedding layer. This layer takes the integer-encoded vocabulary and looks up the embedding vector for each word index. These vectors are learned as the model trains. They add a dimension to the output array, giving it the shape (batch, sequence, embedding).
  • 2. Next, a GlobalAveragePooling1D layer returns a fixed-length output vector for each example by averaging over the sequence dimension (see the sketch after this list). This allows the model to handle input of variable length in the simplest way possible.
  • 3. This fixed-length output vector is piped through a fully connected (Dense) layer with 16 hidden units.
  • 4. The last layer is densely connected to a single output node. With the sigmoid activation function, the result is a float between 0 and 1, representing a probability, or confidence level.
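
To make the pooling step concrete, here is a small sketch (a toy tensor, not part of the tutorial) showing that GlobalAveragePooling1D simply averages over the sequence axis:

x = tf.constant(np.arange(2 * 3 * 4, dtype='float32').reshape(2, 3, 4))  # (batch, sequence, embedding)
pooled = tf.keras.layers.GlobalAveragePooling1D()(x)                     # (batch, embedding)
print(pooled.shape)                                                      # (2, 4)
print(np.allclose(pooled.numpy(), x.numpy().mean(axis=1)))               # True: mean over the sequence axis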

4.1 Hidden units

The above model has two intermediate or "hidden" layers between the input and the output. The number of outputs (units, nodes, or neurons) is the dimension of the representational space for the layer; in other words, the amount of freedom the network is allowed when learning an internal representation.

If a model has more hidden units (a higher-dimensional representation space) and/or more layers, it can learn more complex representations. However, this makes the network more computationally expensive and may lead it to learn unwanted patterns, i.e., patterns that improve performance on the training data but not on the test data. This is called overfitting, and we will explore it later.
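
As an illustration only (a sketch, not part of the original tutorial), a higher-capacity variant of the model might look like this; it has many more parameters and is more prone to overfitting on this dataset:

bigger_model = keras.Sequential([
    keras.layers.Embedding(vocab_size, 32),        # wider embedding
    keras.layers.GlobalAveragePooling1D(),
    keras.layers.Dense(64, activation='relu'),     # more hidden units
    keras.layers.Dense(64, activation='relu'),     # an extra hidden layer
    keras.layers.Dense(1, activation='sigmoid'),
])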

4.2 Loss function and optimizer

A model needs a loss function and an optimizer for training. Since this is a binary classification problem and the model outputs a probability (a single-unit layer with a sigmoid activation), we will use the binary_crossentropy loss function.

This isn't the only choice of loss function; you could, for instance, choose mean_squared_error. But, generally, binary_crossentropy is better for dealing with probabilities: it measures the "distance" between probability distributions, or, in our case, between the ground-truth distribution and the predictions.
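
As a small illustration (toy values, not part of the tutorial), binary cross-entropy can be checked against the formula -mean(y*log(p) + (1-y)*log(1-p)):

y_true = np.array([0., 1., 1., 0.], dtype='float32')
y_pred = np.array([0.1, 0.9, 0.6, 0.4], dtype='float32')   # hypothetical sigmoid outputs

# Keras implementation
print(float(tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)))

# The same formula written out; both prints should agree to several decimals
print(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))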

Later, when we explore regression problems (say, predicting the price of a house), we will use another loss function called mean squared error.

Now, configure the model to use an optimizer and a loss function:

"""
因为采用了sigmoid激活函数,所以损失函数不能用mse均方误差,因为在sigmoid函数的两端梯度很小,会使w和b更新很慢,
所以采用交叉熵代价函数(cross-entropy cost function)
交叉熵代价函数具有非负性和当真实输出与期望输出相近的时候,代价函数接近于零.
"""
model.compile(optimizer='adam',
              loss = 'binary_crossentropy',
              metrics=['accuracy'])

5. Create a validation set

When training, we want to check the accuracy of the model on data it hasn't seen before. Create a validation set by setting apart 10,000 examples from the original training data. (Why not use the testing set now? Our goal is to develop and tune our model using only the training data, and then use the test data just once to evaluate our accuracy.)

x_val = train_data[:10000]  # the first 10,000 of the 25,000 training examples become the validation set
partial_x_train = train_data[10000:]

y_val = train_labels[:10000]
partial_y_train = train_labels[10000:]

6. Train the model

Train the model for 40 epochs in mini-batches of 512 samples. This is 40 iterations over all samples in the x_train and y_train tensors. While training, monitor the model's loss and accuracy on the 10,000 samples from the validation set:

history = model.fit(partial_x_train,
                    partial_y_train,
                    epochs=40,
                    batch_size=512,
                    validation_data=(x_val, y_val),
                    verbose=1)

Train on 15000 samples, validate on 10000 samples
Epoch 1/40
15000/15000 [==============================] - 7s 482us/sample - loss: 0.6926 - accuracy: 0.5099 - val_loss: 0.6915 - val_accuracy: 0.4951
Epoch 2/40
15000/15000 [==============================] - 3s 176us/sample - loss: 0.6884 - accuracy: 0.6069 - val_loss: 0.6842 - val_accuracy: 0.6856
Epoch 3/40
15000/15000 [==============================] - 3s 180us/sample - loss: 0.6763 - accuracy: 0.6802 - val_loss: 0.6690 - val_accuracy: 0.7178
Epoch 4/40
15000/15000 [==============================] - 3s 187us/sample - loss: 0.6542 - accuracy: 0.7407 - val_loss: 0.6437 - val_accuracy: 0.7583
Epoch 5/40
15000/15000 [==============================] - 3s 174us/sample - loss: 0.6213 - accuracy: 0.7821 - val_loss: 0.6089 - val_accuracy: 0.7829
Epoch 6/40
15000/15000 [==============================] - 3s 173us/sample - loss: 0.5791 - accuracy: 0.8080 - val_loss: 0.5680 - val_accuracy: 0.8043
Epoch 7/40
15000/15000 [==============================] - 3s 186us/sample - loss: 0.5323 - accuracy: 0.8266 - val_loss: 0.5240 - val_accuracy: 0.8214
Epoch 8/40
15000/15000 [==============================] - 3s 198us/sample - loss: 0.4844 - accuracy: 0.8474 - val_loss: 0.4826 - val_accuracy: 0.8346
Epoch 9/40
15000/15000 [==============================] - 3s 197us/sample - loss: 0.4398 - accuracy: 0.8617 - val_loss: 0.4446 - val_accuracy: 0.8441
Epoch 10/40
15000/15000 [==============================] - 3s 204us/sample - loss: 0.4006 - accuracy: 0.8743 - val_loss: 0.4134 - val_accuracy: 0.8520
Epoch 11/40
15000/15000 [==============================] - 3s 192us/sample - loss: 0.3667 - accuracy: 0.8835 - val_loss: 0.3877 - val_accuracy: 0.8611
Epoch 12/40
15000/15000 [==============================] - 3s 193us/sample - loss: 0.3385 - accuracy: 0.8902 - val_loss: 0.3679 - val_accuracy: 0.8641
Epoch 13/40
15000/15000 [==============================] - 3s 176us/sample - loss: 0.3147 - accuracy: 0.8963 - val_loss: 0.3519 - val_accuracy: 0.8655
Epoch 14/40
15000/15000 [==============================] - 3s 186us/sample - loss: 0.2941 - accuracy: 0.9015 - val_loss: 0.3378 - val_accuracy: 0.8721
Epoch 15/40
15000/15000 [==============================] - 3s 196us/sample - loss: 0.2762 - accuracy: 0.9070 - val_loss: 0.3270 - val_accuracy: 0.8758
Epoch 16/40
15000/15000 [==============================] - 3s 201us/sample - loss: 0.2608 - accuracy: 0.9115 - val_loss: 0.3180 - val_accuracy: 0.8786
Epoch 17/40
15000/15000 [==============================] - 3s 191us/sample - loss: 0.2468 - accuracy: 0.9170 - val_loss: 0.3111 - val_accuracy: 0.8791
Epoch 18/40
15000/15000 [==============================] - 3s 205us/sample - loss: 0.2341 - accuracy: 0.9214 - val_loss: 0.3050 - val_accuracy: 0.8811
Epoch 19/40
15000/15000 [==============================] - 3s 192us/sample - loss: 0.2226 - accuracy: 0.9245 - val_loss: 0.3008 - val_accuracy: 0.8815
Epoch 20/40
15000/15000 [==============================] - 3s 194us/sample - loss: 0.2122 - accuracy: 0.9285 - val_loss: 0.2976 - val_accuracy: 0.8816
Epoch 21/40
15000/15000 [==============================] - 3s 209us/sample - loss: 0.2026 - accuracy: 0.9313 - val_loss: 0.2932 - val_accuracy: 0.8836
Epoch 22/40
15000/15000 [==============================] - 3s 178us/sample - loss: 0.1930 - accuracy: 0.9354 - val_loss: 0.2906 - val_accuracy: 0.8840
Epoch 23/40
15000/15000 [==============================] - 3s 177us/sample - loss: 0.1844 - accuracy: 0.9397 - val_loss: 0.2888 - val_accuracy: 0.8838
Epoch 24/40
15000/15000 [==============================] - 3s 183us/sample - loss: 0.1772 - accuracy: 0.9425 - val_loss: 0.2877 - val_accuracy: 0.8841
Epoch 25/40
15000/15000 [==============================] - 3s 187us/sample - loss: 0.1690 - accuracy: 0.9461 - val_loss: 0.2872 - val_accuracy: 0.8849
Epoch 26/40
15000/15000 [==============================] - 3s 202us/sample - loss: 0.1623 - accuracy: 0.9493 - val_loss: 0.2860 - val_accuracy: 0.8845
Epoch 27/40
15000/15000 [==============================] - 3s 187us/sample - loss: 0.1552 - accuracy: 0.9514 - val_loss: 0.2872 - val_accuracy: 0.8852
Epoch 28/40
15000/15000 [==============================] - 3s 183us/sample - loss: 0.1493 - accuracy: 0.9545 - val_loss: 0.2868 - val_accuracy: 0.8857
Epoch 29/40
15000/15000 [==============================] - 3s 179us/sample - loss: 0.1432 - accuracy: 0.9572 - val_loss: 0.2875 - val_accuracy: 0.8860
Epoch 30/40
15000/15000 [==============================] - 3s 180us/sample - loss: 0.1379 - accuracy: 0.9583 - val_loss: 0.2882 - val_accuracy: 0.8862
Epoch 31/40
15000/15000 [==============================] - 3s 180us/sample - loss: 0.1322 - accuracy: 0.9608 - val_loss: 0.2897 - val_accuracy: 0.8848
Epoch 32/40
15000/15000 [==============================] - 3s 185us/sample - loss: 0.1276 - accuracy: 0.9631 - val_loss: 0.2918 - val_accuracy: 0.8860
Epoch 33/40
15000/15000 [==============================] - 3s 182us/sample - loss: 0.1222 - accuracy: 0.9653 - val_loss: 0.2931 - val_accuracy: 0.8839
Epoch 34/40
15000/15000 [==============================] - 3s 186us/sample - loss: 0.1181 - accuracy: 0.9667 - val_loss: 0.2951 - val_accuracy: 0.8857
Epoch 35/40
15000/15000 [==============================] - 3s 181us/sample - loss: 0.1127 - accuracy: 0.9685 - val_loss: 0.2972 - val_accuracy: 0.8847
Epoch 36/40
15000/15000 [==============================] - 3s 192us/sample - loss: 0.1086 - accuracy: 0.9700 - val_loss: 0.2997 - val_accuracy: 0.8844
Epoch 37/40
15000/15000 [==============================] - 3s 192us/sample - loss: 0.1044 - accuracy: 0.9718 - val_loss: 0.3024 - val_accuracy: 0.8838
Epoch 38/40
15000/15000 [==============================] - 3s 185us/sample - loss: 0.1005 - accuracy: 0.9733 - val_loss: 0.3051 - val_accuracy: 0.8833
Epoch 39/40
15000/15000 [==============================] - 3s 176us/sample - loss: 0.0971 - accuracy: 0.9745 - val_loss: 0.3095 - val_accuracy: 0.8806
Epoch 40/40
15000/15000 [==============================] - 3s 195us/sample - loss: 0.0935 - accuracy: 0.9763 - val_loss: 0.3113 - val_accuracy: 0.8818

Reference: the Keras model.fit() function

7. Evaluate the model

Let's see how the model performs. Two values are returned: loss (a number representing the error; lower is better) and accuracy.

results = model.evaluate(test_data, test_labels, verbose=2)
results

25000/1 - 3s - loss: 0.3407 - accuracy: 0.8715



[0.330519519739151, 0.87152]

This fairly naive approach achieves an accuracy of about 87%. With more advanced approaches, the model should get closer to 95%.

Create a graph of accuracy and loss over time

model.fit() returns a History object that contains a dictionary with everything that happened during training:

history_dict = history.history
history_dict.keys()



dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])

There are four entries: one for each monitored metric during training and validation. We can use these to plot the training and validation loss and accuracy for comparison:

import matplotlib.pyplot as plt

acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']  # accuracy on the validation set
loss = history_dict['loss']
val_loss = history_dict['val_loss']

epochs = range(1, len(acc) + 1)

# "bo"代表"蓝点"
plt.plot(epochs, loss, 'bo', label='Traing loss')
# b代表"蓝色实线"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')  # 训练和测试(确认/验证)损失
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

[Figure: training and validation loss over epochs]

plt.clf()  # clear the figure
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.show()

[Figure: training and validation accuracy over epochs]

In these plots, the dots represent the training loss and accuracy, and the solid lines are the validation loss and accuracy.

Notice that the training loss decreases with each epoch and the training accuracy increases with each epoch. This is expected when using gradient descent optimization: it should minimize the desired quantity on every iteration.

This isn't the case for the validation loss and accuracy: they seem to peak after about twenty epochs. This is an example of overfitting: the model performs better on the training data than it does on data it has never seen before. After this point, the model over-optimizes and learns representations specific to the training data that do not generalize to the test data.

For this particular case, we could prevent overfitting by simply stopping the training after about twenty epochs. Later, you will see how to do this automatically with a callback (a sketch follows below).
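
A minimal sketch (assuming the same model, data, and hyperparameters as above) of how the Keras EarlyStopping callback could stop training automatically once the validation loss stops improving:

early_stop = keras.callbacks.EarlyStopping(monitor='val_loss',    # watch the validation loss
                                           patience=3,            # stop after 3 epochs without improvement
                                           restore_best_weights=True)

# history = model.fit(partial_x_train, partial_y_train,
#                     epochs=40, batch_size=512,
#                     validation_data=(x_val, y_val),
#                     callbacks=[early_stop], verbose=1)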
