10_ tf.keras Callbacks概述

1. tf.keras Callbacks是什么

Keras官方文档:回调Callbacks

Callbacks的本质是一组函数对象,代码层面就是一个Python List在训练过程中的特定时期被执行,这些函数对象可以在训练过程中访问,保存或者修改训练中的参数,相当于在训练之前写好了几个锦囊,这些锦囊会在特定的时间被打开并且执行。用好Callbacks训练过程将会是一个很愉快的过程。

典型应用场景:

  • 解决训练之后的失控问题。
  • 不知道训练多少轮可以得到想要的结果,这个时候可以通过Callbacks设置当模型不能进一步优化时停止训练。
  • 通过Tensorboard等工具查看训练模型的内在状态和统计,全面直观的监控训练过程。

2. 使用Callbacks

  • 1.实例化Callback
  • 2.以Python List形式传给model.fit()方法的callbacks参数。

例如实现Plateau学习率策略:

reduce_lr = ReduceLROnPlateau(monitor='val_loss',factor=0.2,
								patience=5,min_lr=0.001)
model.fit(x,y,callbacks=[reduce_lr])

3. tf.keras内置Callback函数

动态模型保存 ModelCheckpoint

tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', 
								verbose=0, save_best_only=False, 
								save_weights_only=False, 
								mode='auto', period=1)

主要参数:

  • filepath:模型保存的路径.
  • monitorsave_best_only:监控monitor指定的指标,设置save_best_onlyTrue时可以保存最好的模型,防止模型参数占用太多的硬盘容量.
  • save_weights_only:为True时只保存权重,等于model.save_weights(),为False是保存权重和网络,等于model.save().

动态训练终止 EarlyStopping

tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, 
								patience=0, verbose=0, 
								mode='auto', baseline=None, 
								restore_best_weights=False)

主要参数:

  • min_delta:在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
  • patience: 在监测质量经过多少轮次没有进度时即停止。如果验证频率 (model.fit(validation_freq=5)) 大于 1,则可能不会在每个轮次都产生验证数.
  • mode: {auto, min, max} 其中之一,在 min 模式中,当被监测的数据停止下降训练就会停止;在 max 模式中当被监测的数据停止上升,训练就会停止;在 auto 模式中方向会自动从被监测的数据的名字中判断出来.
  • baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止.
  • restore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重.

远程事件监控 RemoteMonitor

tf.keras.callbacks.RemoteMonitor(root='http://localhost:9000',
								 path='/publish/epoch/end/',
								 field='data', headers=None, 
								 send_as_json=False)

主要参数

  • root: 目标服务器的根地址.
  • path: 相对于 root 的路径,事件数据被送达的地址.
  • field: JSON ,数据被保存的领域.
  • headers: 可选自定义的 HTTP 的头字段.
  • send_as_json: 请求是否应该以 application/json 格式发送.

自定义动态学习率 ReduceLROnPlateau

tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)

主要参数

  • schedule: 一个函数,接受轮索引数作为输入(整数,从 0 开始迭代) 然后返回一个学习速率作为输出(浮点数).

数据可视化 Tensorboard

tf.keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, 
									batch_size=32, write_graph=True,
									write_grads=False, write_images=False,
									embeddings_freq=0, 
									embeddings_layer_names=None,
								    embeddings_metadata=None, embeddings_data=None, 
								    update_freq='epoch')

简单自定义Callback LambdaCallback

tf.keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, 
										on_batch_begin=None, on_batch_end=None,
									    on_train_begin=None, on_train_end=None)

猜你喜欢

转载自blog.csdn.net/PecoHe/article/details/105100571