tf 加速训练,节省内存 fit_generator

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

直接给出关键代码,数据处理和组网部分略过

处理后的数据格式如下:

在这里插入图片描述

  • 第一个列表是特征,[39, 500000]
  • 第二个列表是label

fit_generator

batch_size = 2048
model.fit_generator(
    GeneratorRandomPatchs(train_x, train_y, batch_size),
    validation_data=(val_x, val_y),
    steps_per_epoch=len(train_data) // batch_size,
    epochs=100,
    verbose=1
)
def GeneratorRandomPatchs(train_x, train_y, batch_size):
    totl, col = np.array(train_x).shape  # (39, 500000)  特征数、样本数
    # 保证 steps_per_epoch * epoch 批次的数据够
    while True:
        for index in range(0, col, batch_size):
            xs, ys = [], []
            for t in range(totl):
                xs.append(train_x[t][index: index + batch_size])
            ys.append(train_y[0][index: index + batch_size])
            # print(np.array(xs).shape, np.array(ys).shape)
            yield (xs, ys)

fit_generator 需要传递一个迭代器,如上述例子:GeneratorRandomPatchs,通过yield返回训练数据

  • batch_size:批处理大小,就是每次入模的样本数。
  • steps_per_epoch:每个epoch要处理的批数。比如训练数据50W,batch_size是2048,那么一个epoch的批数就是244。

在这里插入图片描述

其它参数解释
在这里插入图片描述
在这里插入图片描述


参考自:

https://blog.csdn.net/zhangpeterx/article/details/90900118

https://blog.csdn.net/qq_39783265/article/details/106752903

扫描二维码关注公众号,回复: 13197402 查看本文章

https://www.jb51.net/article/188905.htm

https://my.oschina.net/u/4329662/blog/3639783

猜你喜欢

转载自blog.csdn.net/qq_42363032/article/details/121335701