keras大数据量训练解决方法

当数据量很大时无法将数据全部读入内存运算,报错,可以改用批处理解决问题。

一.pandas读数据时可以设置成批量读入

二.使用keras中的train_on_batch方法

示例代码:

reader = pd.read_table('tmp.sv', sep=',', chunksize=4)

main_input = Input(shape=(50,16),name='main_input')
tmp = LSTM(32, return_sequences=True,dropout=0,name = 'lstm1')(main_input)
tmp = LSTM(32, return_sequences=False,name = 'lstm2')(tmp)
out = Dense(1, activation='sigmoid')(tmp)
model = Model(inputs=main_input, outputs=out)
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])

for chunk in reader:
    cost = model.train_on_batch(X_train, Y_train)

predicted = model.predict(x_ot,batch_size=32)
score, acc = model.evaluate(x_ot, y_ot,batch_size=32)

三.注意点:

1.模型训练完之后保存有两种方式:

方式一:

from keras.models import load_model

model.save('my_model.h5') 

model = load_model('my_model.h5')

这种方式保存模型全部的参数,包括指定的是cpu还是GPU,如果再GPU上训练但是拿到cpu上去用就会报错,所以用以下方法可以避免

model.save_weights('my_model_weights.h5')

model.load_weights('my_model_weights.h5')

详细参考:https://blog.csdn.net/u010159842/article/details/54407745

猜你喜欢

转载自blog.csdn.net/weixin_42247685/article/details/81904221