10_ tf.keras Callbacks概述
1. tf.keras Callbacks是什么
Callbacks
的本质是一组函数对象,代码层面就是一个Python List
,在训练过程中的特定时期被执行,这些函数对象可以在训练过程中访问,保存或者修改训练中的参数,相当于在训练之前写好了几个锦囊,这些锦囊会在特定的时间被打开并且执行。用好Callbacks
训练过程将会是一个很愉快的过程。
典型应用场景:
- 解决训练之后的失控问题。
- 不知道训练多少轮可以得到想要的结果,这个时候可以通过
Callbacks
设置当模型不能进一步优化时停止训练。 - 通过
Tensorboard
等工具查看训练模型的内在状态和统计,全面直观的监控训练过程。
2. 使用Callbacks
- 1.实例化
Callback
。 - 2.以
Python List
形式传给model.fit()
方法的callbacks
参数。
例如实现Plateau
学习率策略:
reduce_lr = ReduceLROnPlateau(monitor='val_loss',factor=0.2,
patience=5,min_lr=0.001)
model.fit(x,y,callbacks=[reduce_lr])
3. tf.keras内置Callback函数
tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss',
verbose=0, save_best_only=False,
save_weights_only=False,
mode='auto', period=1)
主要参数:
filepath
:模型保存的路径.monitor
和save_best_only
:监控monitor
指定的指标,设置save_best_only
为True
时可以保存最好的模型,防止模型参数占用太多的硬盘容量.save_weights_only
:为True
时只保存权重,等于model.save_weights()
,为False
是保存权重和网络,等于model.save()
.
tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
patience=0, verbose=0,
mode='auto', baseline=None,
restore_best_weights=False)
主要参数:
min_delta
:在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。patience
: 在监测质量经过多少轮次没有进度时即停止。如果验证频率 (model.fit(validation_freq=5)) 大于 1,则可能不会在每个轮次都产生验证数.mode
: {auto, min, max} 其中之一,在 min 模式中,当被监测的数据停止下降训练就会停止;在 max 模式中当被监测的数据停止上升,训练就会停止;在 auto 模式中方向会自动从被监测的数据的名字中判断出来.baseline
: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止.restore_best_weights
: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重.
扫描二维码关注公众号,回复:
12436397 查看本文章

tf.keras.callbacks.RemoteMonitor(root='http://localhost:9000',
path='/publish/epoch/end/',
field='data', headers=None,
send_as_json=False)
主要参数
root
: 目标服务器的根地址.path
: 相对于 root 的路径,事件数据被送达的地址.field
: JSON ,数据被保存的领域.headers
: 可选自定义的 HTTP 的头字段.send_as_json
: 请求是否应该以 application/json 格式发送.
tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
主要参数
schedule
: 一个函数,接受轮索引数作为输入(整数,从 0 开始迭代) 然后返回一个学习速率作为输出(浮点数).
tf.keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0,
batch_size=32, write_graph=True,
write_grads=False, write_images=False,
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None, embeddings_data=None,
update_freq='epoch')
tf.keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None,
on_batch_begin=None, on_batch_end=None,
on_train_begin=None, on_train_end=None)