Keras .fit和.fit_generator函数

参考博客 

https://blog.csdn.net/learning_tortosie/article/details/85243310

在本教程中,您将了解Keras .fit.fit_generator函数的工作原理,包括它们之间的差异。为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数据生成器。

Keras深度学习库包括三个独立的函数,可用于训练您自己的模型:

  • .fit
  • .fit_generator
  • .train_on_batch
  • 这三个函数基本上可以完成相同的任务,但他们如何去做这件事是非常不同的。

    让我们逐个探索这些函数,查看函数调用的示例,然后讨论它们彼此之间的差异。

调用.fit

model.fit(trainX, trainY, batch_size=32, epochs=50)

在这里可以看到提供的训练数据(trainX)和训练标签(trainY)。然后,我们指示Keras允许我们的模型训练50个epoch,同时batch size为32

.fit的调用在这里做出两个主要假设:

  • 我们的整个训练集可以放入RAM
  • 没有数据增强(即不需要Keras生成器)

我们的网络将在原始数据上训练。原始数据本身适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。此外,我们不会使用数据增强动态操纵训练数据。

对于小型,简单化的数据集,使用Keras的.fit函数是完全可以接受的。

这些数据集通常不是很具有挑战性,不需要任何数据增强。

但是,真实世界的数据集很少这么简单:

  • 真实世界的数据集通常太大而无法放入内存中
  • 它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力

在这些情况下,我们需要利用Keras的.fit_generator函数,函数原型为,

fit_generator(self, generator,            
                    steps_per_epoch=None, 
                    epochs=1, 
                    verbose=1, 
                    callbacks=None, 
                    validation_data=None, 
                    validation_steps=None,  
                    class_weight=None,
                    max_queue_size=10,   
                    workers=1, 
                    use_multiprocessing=False, 
                    shuffle=True, 
                    initial_epoch=0)

优点:通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。 

参数:

  • generator:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。
  • steps_per_epoch:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
  • epochs:整数,在数据集上迭代的总数。
  • works:在使用基于进程的线程时,最多需要启动的进程数量。
  • use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。
# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest")

# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
	validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
	epochs=EPOCHS)

我们首先初始化将要训练的网络的epoch和batch size。

然后我们初始化aug,这是一个Keras ImageDataGenerator对象,用于图像的数据增强,随机平移,旋转,调整大小等。

执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。

但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。

根据提供给ImageDataGenerator的参数随机调整每批新数据。

因此,我们现在需要利用Keras的.fit_generator函数来训练我们的模型。

该函数本身是一个Python生成器

Keras在使用.fit_generator训练模型时的过程:

  • Keras调用提供给.fit_generator的生成器函数(在本例中为aug.flow
  • 生成器函数为.fit_generator函数生成一批大小为BS的数据
  • .fit_generator函数接受批量数据,执行反向传播,并更新模型中的权重
  • 重复该过程直到达到期望的epoch数量
  • 您会注意到我们现在需要在调用.fit_generator时提供steps_per_epoch参数(.fit方法没有这样的参数)。

    为什么我们需要steps_per_epoch

    请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。

    由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。

    因此,我们将训练数据的总数除以批量大小的结果作为steps_per_epoch的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。

  • 图像数据集作为CSV文件?

    在这里插入图片描述
    将在这里使用的数据集是Flowers-17数据集,它是17种不同花种的集合,每个类别有80个图像。

    我们的目标是培训Keras卷积神经网络,以正确分类每种花卉。

    但是,这个项目有点不同:

  • 不是使用存储在磁盘上的原始图像文件
  • 而是将整个图像数据集序列化为两个CSV文件(一个用于训练,一个用于评估)
  • 要构建每个CSV文件,我:

  • 循环输入数据集中的所有图像
  • 我们的目标是现在编写一个自定义Keras生成器来解析CSV文件,并为.fit_generator函数生成批量图像和标签。

  • 将它们调整为 64×64 像素
  • 将 64x64x3 = 12,288 个RGB像素的强度展平为单个列表
  • 在CSV文件中写入12,288个像素值和类标签(每行一个)

猜你喜欢

转载自blog.csdn.net/weixin_38145317/article/details/88709727
今日推荐