一、断点续训
为防止突然断电、参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt
的操作,给所有w和b赋保存在ckpt中的值:
(1) 如果存储断点文件的目录文件夹中,包含有效断点状态文件,则返回该文件:
- 参数说明
checkpoint_dir
: 表示存储断点文件的目录
latest_filename
: 断点文件的可选名称,默认为checkpoint
ckpt = tf.train.get_checkpoint_state(checkpoint_dir,\
latest_filename = None)
**(2)**如果ckpt存在,且保存的模型在指定路径中存在
if ckpt and ckpt.model_checkpoint_path:
恢复当前会话,将ckpt中的值赋给 w 和 b
参数说明:sess:表示当前会话,之前保存的结果会被加载入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,
因为有了位置会自动去查看checkpoint文件,看最新的模型叫什么
【注】内容来自mooc人工智能实践第六讲