Training on very large datasets with Keras fit_generator

For a small dataset we can read all of it into main (CPU) memory at once and then feed it to the GPU batch by batch; a simple call to `fit` is enough. When the dataset is very large, however, reading everything at once overflows main memory (this has nothing to do with the GPU: the GPU has its own memory, which only overflows when batch_size is too large), so we need to train with `fit_generator` instead.

Its signature is as follows:

fit_generator(generator, steps_per_epoch=None,
              epochs=1, verbose=1,
              callbacks=None, validation_data=None,
              validation_steps=None,  class_weight=None,
              max_queue_size=10, workers=1,
              use_multiprocessing=False, shuffle=True, initial_epoch=0)

It has many parameters, the most important being `generator`. You can pass either a plain Python generator or a `keras.utils.Sequence`; I personally recommend the latter. A `Sequence` subclass must override two methods:

  • `__len__()`
  • `__getitem__()`

It is really quite simple; the following small example should be enough for you to write your own data generator.

The full code is below; each method is explained in turn afterwards.

import numpy
import matplotlib.image as mpimg  # read image
from keras.utils import Sequence, to_categorical

class DataGenerator(Sequence):

    def __init__(self, files, batch_size=1, shuffle=True):
        """
        # Arguments
        ---
            files: file names (without extension) of all training samples.
            batch_size: size of one batch. """

        self.batch_size = batch_size
        self.files = files
        self.indexes = numpy.arange(len(self.files))
        self.shuffle = shuffle

    def __len__(self):
        """return: steps num of one epoch. """
        return len(self.files) // self.batch_size

    def __getitem__(self, index):
        """Gets the `index-th` batch.
        ---
        # Arguments
            index: position of the batch in the Sequence.
        # Returns
            A batch data. """

        # get batch data inds.
        batch_inds = self.indexes[index *
                                  self.batch_size:(index+1)*self.batch_size]
        # get batch data file name.
        batch_files = [self.files[k] for k in batch_inds]

        # read batch data
        batch_images, batch_labels = self._read_data(batch_files)

        return batch_images, batch_labels

    def on_epoch_end(self):
        """shuffle data after one epoch. """
        if self.shuffle:
            numpy.random.shuffle(self.indexes)

    def _read_data(self, batch_files):
        """Read a batch data.
        ---
        # Arguments
            batch_files: the file of batch data.

        # Returns
            images: (batch_size, (image_shape)).
            labels: (batch_size, (label_shape)). """

        images = []
        labels = []

        for file in batch_files:
            image = mpimg.imread('data/Images/'+file+'.jpg')
            images.append(image)
            label = numpy.loadtxt('data/labels/'+file+'.arr', dtype=int)
            labels.append(label)  # to one hot

        return numpy.array(images), numpy.array(labels)

__init__

First, define your own generator class `DataGenerator`, inheriting from the `Sequence` class, and initialize its parameters:

class DataGenerator(Sequence):
    def __init__(self, files, batch_size=1, shuffle=True):
        """
        # Arguments
        ---
            files: file names (without extension) of all training samples.
            batch_size: size of one batch. """
        self.batch_size = batch_size
        self.files = files
        self.indexes = numpy.arange(len(self.files))
        self.shuffle = shuffle
  • files: file paths (names) of all training samples
  • batch_size: size of one batch

For example, in my image segmentation task the images look like this:
(image omitted)
Each label is a matrix of per-pixel class indices, one matrix per image:
(image omitted)
So all I need is the file name without its extension; when reading an image or a label I simply append the corresponding suffix. The `files` list passed in therefore looks roughly like this:
(image omitted)
An `indexes` array is also created, which is later used to slice the dataset into batches.
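The file-name handling described above can be sketched like this (the file names here are made up for illustration):

```python
# Hypothetical directory listing; the real one would come from os.listdir.
image_files = ['img_001.jpg', 'img_002.jpg', 'img_003.jpg']

# Strip the extension, keeping only the base name.
files = [f.split('.')[0] for f in image_files]
print(files)  # ['img_001', 'img_002', 'img_003']

# Rebuild the full paths when reading, as the generator does.
image_path = 'data/Images/' + files[0] + '.jpg'
label_path = 'data/labels/' + files[0] + '.arr'
```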

__len__

`__len__` returns the number of steps per epoch, i.e. how many batches are trained in each epoch (for a `Sequence`, `fit_generator` uses this value as the default `steps_per_epoch`).

def __len__(self):
    """return: steps num of one epoch. """
    return len(self.files) // self.batch_size
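One caveat: integer division drops the final partial batch, so a few samples may never be seen in an epoch. If you want the remainder as well, use ceiling division and let `__getitem__` tolerate a short last batch. A quick comparison:

```python
import math

n_files, batch_size = 10, 4

steps_floor = n_files // batch_size           # drops the short final batch
steps_ceil = math.ceil(n_files / batch_size)  # keeps it

print(steps_floor, steps_ceil)  # 2 3
```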

__getitem__

`__getitem__` must return one batch of data, inputs and labels together. Its argument `index` (supplied automatically during training) identifies the index-th batch. From `index` we compute the positions of this batch's samples (`batch_inds`), look up the corresponding file names (`batch_files`), and hand those to `_read_data`, which reads the batch's images and labels.

def __getitem__(self, index):
    """Gets the `index-th` batch.
    ---
    # Arguments
        index: position of the batch in the Sequence.
    # Returns
        A batch data. """

    # get batch data inds.
    batch_inds = self.indexes[index *
                                self.batch_size:(index+1)*self.batch_size]
    # get batch data file name.
    batch_files = [self.files[k] for k in batch_inds]

    # read batch data
    batch_images, batch_labels = self._read_data(batch_files)

    return batch_images, batch_labels
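The slicing logic can be checked in isolation. With 10 samples and batch_size=4, batch index 1 should cover positions 4 to 7:

```python
import numpy

indexes = numpy.arange(10)
batch_size = 4
index = 1

# Same slice as in __getitem__ above.
batch_inds = indexes[index * batch_size:(index + 1) * batch_size]
print(batch_inds.tolist())  # [4, 5, 6, 7]
```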

_read_data

This method reads the data for your particular task, given the file names of one batch.

def _read_data(self, batch_files):
    """Read a batch data.
    ---
    # Arguments
        batch_files: the file of batch data.

    # Returns
        images: (batch_size, (image_shape)).
        labels: (batch_size, (label_shape)). """

    images = []
    labels = []

    for file in batch_files:
        image = mpimg.imread('data/Images/'+file+'.jpg')
        images.append(image)
        label = numpy.loadtxt('data/labels/'+file+'.arr', dtype=int)
        labels.append(label)  # to one hot

    return numpy.array(images), numpy.array(labels)
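The `# to one hot` comment hints that the integer label matrix should be one-hot encoded, which is why `to_categorical` is imported above. A NumPy-only sketch of that encoding (the tiny 2x2 label map is made up for illustration):

```python
import numpy

def one_hot(label_matrix, num_classes):
    # Each integer label selects a row of the identity matrix,
    # mirroring what keras.utils.to_categorical does.
    return numpy.eye(num_classes, dtype='float32')[label_matrix]

labels = numpy.array([[0, 2],
                      [1, 2]])  # a 2x2 per-pixel label map
encoded = one_hot(labels, num_classes=3)
print(encoded.shape)  # (2, 2, 3)
```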

on_epoch_end

This method is called at the end of every epoch; here it shuffles the dataset if `shuffle` is enabled.

def on_epoch_end(self):
    """shuffle data after one epoch. """
    if self.shuffle:
        numpy.random.shuffle(self.indexes)

That completes the data generator. When building the model, the example below may help: it builds an FCN32 model and trains it on two GPUs. With two GPUs, a batch_size of 1 raises an error, because the second GPU would receive no data, so batch_size must be at least 2; with a single GPU this restriction does not apply.

import os

def get_generator(num_classes, batch_size=2, preprocess=True, shuffle=True, train_ratio=0.8):
    """Get data generator for training and test files.
    ---
    # Arguments
        num_classes: number of classes.
        preprocess: whether to preprocess image data.
        shuffle: shuffle data after each epoch.
        train_ratio: fraction of data used for training.

    # Returns
        generator: data generator for fit_generator.
        test_files: test model after training. """

    images_files = os.listdir('data/Images')
    allfiles = []
    for file in images_files:
        allfiles.append(file.split('.')[0])  # get image name
    N = len(allfiles)
    train_files = allfiles[:int(N*train_ratio)]
    test_files = allfiles[int(N*train_ratio):]
    # num_classes and preprocess would be consumed by a fuller DataGenerator;
    # the class defined above only takes files, batch_size and shuffle.
    generator = DataGenerator(train_files,
                              batch_size=batch_size,
                              shuffle=shuffle)
    return generator, test_files

num_classes = 34
# Get generator.
generator, test_files = get_generator(num_classes,
                                      batch_size=2,
                                      preprocess=True,
                                      shuffle=True,
                                      train_ratio=0.8)
# Build model.
model = FCN32.build(height=256, width=256, num_classes=num_classes)
model = multi_gpu_model(model, gpus=2)  # two GPUs, so batch_size must be at least 2
model.summary()
model.compile(optimizer=Adam(0.00001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# fit model.
model.fit_generator(generator, epochs=20)


Reprinted from blog.csdn.net/weixin_43486780/article/details/105365930