前言
在跑这个系列的代码的时候,发现数据太大了,9G呢~~~,所以引入数据生成器来减轻一下负担
一、DataGenerator是什么?
示例:数据生成器是深度学习训练时的一个技巧,那就是构造生成器generator
并且用Keras
的fit_generator
来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator
是keras
的Sequence
类的继承类,一般要包含__len__
,__getitem__
, on_epoch_end
等方法
二、使用步骤
def __init__
初始化def __len__(self):
返回生成器的长度,也就是总共分批生成数据的次数。def __getitem__(self, index):
该函数返回每次我们需要的经过处理的数据。def on_epoch_end(self):
该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。def __data_generation(self, list_IDs_temp):
主要是迭代读入图片,对图片预处理。
一般 def __getitem__(self, index):
会调用 def __data_generation(self, list_IDs_temp):
我这个程序里就联合在一起了
代码如下(示例):
class DataGenerator(Sequence):
def __init__(self, x, gender, y, batch_size=1, shuffle=False):
self.batch_size = batch_size
self.x = x
self.gender = gender
self.y = y
self.indexes = np.arange(len(self.x))
self.shuffle = shuffle
self.n = 0
self.max = self.__len__()
def __len__(self):
#计算每一个epoch的迭代次数
return math.ceil(len(self.x) / float(self.batch_size))
def __getitem__(self, index):
#生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
# 生成batch_size个索引
batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# 根据索引获取datas集合中的数据
batch_x = np.asarray([self.x[k] for k in batch_indexs], dtype=np.float32)
batch_x /= 255.
batch_gender = [self.gender[k] for k in batch_indexs]
batch_y = [self.y[k] for k in batch_indexs]
# print("batch_x:",np.array(batch_x).shape)
# print("batch_gender:",np.array(batch_gender).shape)
# print("batch_y:",np.array(batch_y).shape)
# 生成数据
return [np.array(batch_x),np.array(batch_gender)], np.array(batch_y)
def on_epoch_end(self):
#在每一次epoch结束是否需要进行一次随机,重新随机一下index
if self.shuffle == True:
np.random.shuffle(self.indexes)
def __next__(self):
if self.n >= self.max:
self.n = 0
result = self.__getitem__(self.n)
self.n += 1
return result
总结
下次遇到数据集较大的模型就使用这个生成器吧,释放内存杠杠的。