Keras (5): fit_generator

Copyright notice: this is an original article by the blogger and may not be reproduced without permission. https://blog.csdn.net/jiangpeng59/article/details/79515680

1.fit_generator
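The official documentation describes every parameter of fit_generator. Here are the most commonly used ones:
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either a tuple (inputs, targets) or a tuple (inputs, targets, sample_weights).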

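This generator is essentially the same as an ordinary Python generator: both use yield to return the data. The difference is that here the yield sits inside an infinite loop. Keras keeps pulling batches from the generator across epochs, so it must never run out.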
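The toy sketch below is plain Python with no Keras involved. It shows why the infinite loop matters. `infinite_batches` is a hypothetical name. Because it is wrapped in `while True`, the generator restarts from the beginning of its data instead of raising StopIteration. A consumer can therefore keep calling next() for as many epochs as it likes.

```python
import itertools


def infinite_batches(data, batch_size):
    """Hypothetical example: cycle over an in-memory list forever, yielding batches."""
    while True:  # never exhausted: after the last batch, start over from the beginning
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]


gen = infinite_batches([1, 2, 3, 4, 5], batch_size=2)
# Pull more batches than a single pass over the data contains
first_five = list(itertools.islice(gen, 5))
print(first_five)  # [[1, 2], [3, 4], [5], [1, 2], [3, 4]]
```

Here itertools.islice stands in for the training loop. It pulls five batches, while one pass over the data only contains three.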

steps_per_epoch: Integer. Total number of steps (batches of samples) to yield from generator before declaring one epoch finished and starting the next epoch. It should typically be equal to the number of samples of your dataset divided by the batch size.

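In other words, steps_per_epoch is the number of batches drawn from the generator in one epoch. The last sentence of the quote states it directly: steps_per_epoch = train_size // batch_size.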
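Here is a quick check of that arithmetic. It uses the sizes this post assumes later (batch_size=32, train_size=1314). When train_size is not a multiple of batch_size, floor division counts only the full batches. math.ceil counts the final partial batch as one extra step. Which one fits depends on whether your generator ever yields a partial batch. The batched generator later in this post never does, because leftover records carry over into the next pass over the file.

```python
import math

batch_size, train_size = 32, 1314  # assumed sizes, matching the example later in the post

steps_floor = train_size // batch_size           # number of full batches
steps_ceil = math.ceil(train_size / batch_size)  # full batches plus one partial batch
leftover = train_size % batch_size               # samples not covered by the full batches

print(steps_floor, steps_ceil, leftover)  # 41 42 2
```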

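In short, fit_generator is for training sets that are too large to load into memory all at once: each batch is read from disk only when it is needed.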

Below is a demo from the official Keras documentation:

def generate_arrays_from_file(path):
    while True:
        with open(path) as f:
            for line in f:
                # create numpy arrays of input data
                # and labels, from each line in the file
                x1, x2, y = process_line(line)
                yield ({'input_1': x1, 'input_2': x2}, {'output': y})

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                    steps_per_epoch=10000, epochs=10)
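The demo leaves process_line undefined. Also, the dict keys 'input_1', 'input_2' and 'output' have to match the names of the model's input and output layers. Below is a minimal sketch of one possible process_line. It assumes each line of the file holds three comma-separated numbers, "x1,x2,y"; that format is an assumption, not part of the Keras demo. Each returned array keeps a leading batch axis of size 1, so every yield of the demo generator is a batch holding a single sample.

```python
import numpy as np


def process_line(line):
    """Hypothetical parser: assumes each line is 'x1,x2,y' with numeric fields.

    Returns arrays shaped (1, 1): a batch axis of size 1 plus one feature.
    """
    x1, x2, y = (float(v) for v in line.strip().split(","))
    return np.array([[x1]]), np.array([[x2]]), np.array([[y]])


x1, x2, y = process_line("0.5,1.5,1\n")
print(x1.shape, x2.shape, y.shape)  # (1, 1) (1, 1) (1, 1)
```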

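One problem with this demo is that batch_size appears nowhere (assume for now that process_line produces only one sample per line). The code below makes generate_arrays_from_file yield one full batch at a time.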

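import numpy as np

def generate_arrays_from_file(path, batch_size):
    list_x = []
    list_y = []  # buffers for the batch currently being assembled
    count = 0
    while True:
        with open(path) as f:
            for line in f:
                # process_line returns a single record: one x and one y
                x, y = process_line(line)
                list_x.append(x)
                list_y.append(y)
                count += 1
                if count >= batch_size:  # yield only once batch_size records are collected
                    # convert to numpy arrays: Keras reads a Python list of arrays as multiple model inputs
                    yield (np.array(list_x), np.array(list_y))
                    count = 0
                    list_x = []; list_y = []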

# assumed values for batch_size and train_size
batch_size, train_size = 32, 1314
model.fit_generator(generate_arrays_from_file('/my_file.txt', batch_size),
                    steps_per_epoch=train_size // batch_size, epochs=10)
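The sketch below checks the batched generator without a model or a real data file. Several pieces are assumptions made for illustration: the names `process_line` and `batches_from_file`, and the comma-separated line format "feature1,feature2,label". The sketch uses `len(list_x)` in place of the separate `count` variable and converts each batch to numpy arrays before yielding. It writes a 5-record temporary file, pulls three batches of size 2, and prints the first feature of each record in every batch.

```python
import itertools
import os
import tempfile

import numpy as np


def process_line(line):
    """Hypothetical parser: assumes each line looks like 'feature1,feature2,label'."""
    *features, label = (float(v) for v in line.strip().split(","))
    return features, label


def batches_from_file(path, batch_size):
    """Yield (x, y) numpy batches forever, re-reading the file on every pass."""
    list_x, list_y = [], []
    while True:
        with open(path) as f:
            for line in f:
                x, y = process_line(line)
                list_x.append(x)
                list_y.append(y)
                if len(list_x) >= batch_size:
                    yield np.array(list_x), np.array(list_y)
                    list_x, list_y = [], []


# Write a tiny 5-record file so the sketch runs without real data
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    for i in range(5):
        tmp.write(f"{i},{i * 10},{i % 2}\n")
    path = tmp.name

gen = batches_from_file(path, batch_size=2)
batches = list(itertools.islice(gen, 3))  # stand-in for Keras pulling 3 steps
gen.close()  # closes the file the suspended generator still holds open
os.remove(path)

for x, y in batches:
    print(x.shape, x[:, 0].tolist(), y.tolist())
# (2, 2) [0.0, 1.0] [0.0, 1.0]
# (2, 2) [2.0, 3.0] [0.0, 1.0]
# (2, 2) [4.0, 0.0] [0.0, 0.0]
```

The third batch combines the last record of the first pass (4) with the first record of the second pass (0). Records left over when the file ends are not dropped. They carry over into the next pass, so an "epoch" of steps_per_epoch batches and one pass over the file need not line up exactly.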

