Keras中的fit和fit_generator

Keras是超级无敌好入手的AI框架之一了,极其人性化的设计受到了本人的吹爆。然而,keras中比较难理解的地方还是存在的,比如说这个fit_generator

在模型搭建完compile以后,一行"model.fit_generator(xxx)"就可以完成训练。真正让服务器开始忙的就是这一行代码。

keras给模型喂入数据的函数有fitfit_generator

我们知道fit的用法,十分简便,把x和y封装成ndarray数组就可以使用了,我们看一个简单的例子:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils

# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)

# 搭建网络
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=y)
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

# 喂入数据
print(type(x_train))  # numpy.ndarray
model.fit(x_train, y_train, batch_size=32, epochs=10)

用宇宙最通用数据集MNIST来举例model.fit的用法:x_train和y_train的type都是多维numpy数组,所以可以直接按进fit的括号里。batchsize和epochs额外设置。

所以,我们训练Keras模型的时候,可以直接把所有训练数据封装成多维数组,然后直接用model.fit()将数据灌进去训练。这真的是通用的方法吗? 

显然不是,我们能这么做是因为咱家用的MNIST数据集,掰开了揉也才十几兆。试想,如果训练集是COCO,那你还会把COCO训练集全部包装成一个数组? 你电脑有那么大内存吗?!

对于这种情况,我们就不得不用到fit_generator了。接下来,正式介绍keras中的fit_generator


fit_genrator

keras中的fit_generator是keras用来为训练模型生成批次数据的工具。它的输入可以是一个python生成器,也可以是一个Keras Sequence。

用生成器喂数据的优点,除了不依赖大内存以外,还具有非常多的优点。比如,它可以并行预处理数据,咱们的模型是运行在GPU上的,generator运行在CPU上,所以GPU跑模型,CPU预处理数据,这样就可以达到很高的时效性。

1. 用python生成器作为输入

python生成器的核心就是一个yield关键字,具体定义可以参考这个1这个2, 以及我之前写的一篇《yield关键字》。我们来看一个来自keras官方的例子:

def generate_arrays_from_file(path):
            while True:
                with open(path) as f:
                    for line in f:
                        # create numpy arrays of input data
                        # and labels, from each line in the file
                        x1, x2, y = process_line(line)
                        yield ({'input_1': x1, 'input_2': x2}, {'output': y})

model.fit_generator(generate_arrays_from_file('./my_folder'),
                            steps_per_epoch=10000, epochs=10)

我们先用一个函数generate_arrays_from_file作为一个生成器,这个生成器可以一点点地在目录文件夹里抓数据,然后包装数据喂给模型。缺点是不太方便设置batch_size。

博主推荐指数:★★★★☆

2. 用keras Sequence实例作为输入

keras.utils.data_utils子模块中提供了Sequence类,可以作为数据序列使用。

用法:1. 新建一个基于Sequence的子类,比如MnistSequence,子类必须继承'__getitem__'和'__len__'这两个类方法。(因为这两个类方法是fit_generator读取数据的接口)

2. 实例化新建子类,再输入到fit_generator中。

我们依旧来看一段代码来理解:

class MnistSequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
                resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)

model.fit_generator(MnistSequence(training_x, training_y, 32), epochs=10000)

用Sequence子类可以实现将数据整理成数据序列,排序喂入模型。然而,上面的例子不足以表明用这个方法的优越性,看起来好像还是需要把training_x和training_y整理成多维数组才能使用。其实不然,我们可以继续在子类里添加读取文件夹的函数分批次获取数据,然后可以实现如python生成器一样的效果。同时,还可以添加任意预处理函数。

博主推荐指数:★★★★★

猜你喜欢

转载自blog.csdn.net/leviopku/article/details/87912097