keras学习

断点训练方法:

在compile之后加入ModelCheckpoint:

cnn_net.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['acc'])
path="C:/Users/Administrator/Desktop/pytest/weights.{epoch:02d}-{val_loss:.2f}.hdf5"
checkpoint = ModelCheckpoint(path,
    monitor='loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)

其中命名为 weights.{epoch:02d}-{val_loss:.2f}.hdf5,模型被保存的的文件名就会有训练轮数和验证损失。

fit之前载入hdf5文件,就可以继续训练。

if os.path.exists("C:/Users/Administrator/Desktop/pytest/weights.05-0.03.hdf5"):
    cnn_net.load_weights("C:/Users/Administrator/Desktop/pytest/weights.05-0.03.hdf5")   
    print("checkpoint_loaded")
#训练模型
#validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等
hist=cnn_net.fit(x_train,y_train,batch_size=batch_size,epochs=15,verbose=1,validation_split=0.2,callbacks=[checkpoint],initial_epoch=5)#50

其中callbacks=[checkpoint]用来执行回调,initial_epoch=5设置起点轮数,设为5就会从第6轮开始。

猜你喜欢

转载自blog.csdn.net/qq_40250862/article/details/82729156