使用model.fit_generator方法进行训练(自己的训练集-多分类)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xfjjs_net/article/details/84798045

我们在使用model.fit()进行训练的时候, 在这之前你肯定会有训练集的x_img_train,y_label_train两个参数。

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

但是当我们使用model.fit_generator()的时候,它的方法是这样的:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

可以看到它要求传入的参数是一个generator.官网说的很清楚,(不清楚的可以看官网)这里的generator是一个生成器,主要是训练自己的数据,并且数据非常多的时候可以不用把数据全部加载进内存,而是用生成器自己一点点读取。大大提高的运行效率。

下面是这个生成器的生成方法:

#这是训练集的生成器
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

## 训练图片生成器
train_generator = train_datagen.flow_from_directory(
    train_data_dir,#训练样本地址
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical') #多分类

test_datagen = ImageDataGenerator(rescale=1. / 255)

##验证集的生成器
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,#验证样本地址
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False) #多分类

好了,有了这个train_generator生成器我们就可以入入fit_generator(...)里面进行训练了。

对了,这里说明下train_data_dir / validation_data_dir 是我本机的训练集与验证集的地址。

目录结构形似:

'''
data/train/
          1/
             001.jpg
             002.jpg
             ...
          2/
            001.jpg
            002.jpg
            ...

data/validation/
                1/
                    001.jpg
                    002.jpg
                    ...
                2/
                    001.jpg
                    002.jpg
                    ...
            
'''

猜你喜欢

转载自blog.csdn.net/xfjjs_net/article/details/84798045