python神经网络(五)输入手写数字进行识别

一、断点续训
为防止突然断电、参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt的操作,给所有w和b赋保存在ckpt中的值:

(1) 如果存储断点文件的目录文件夹中,包含有效断点状态文件,则返回该文件:

  • 参数说明
    checkpoint_dir: 表示存储断点文件的目录
    latest_filename: 断点文件的可选名称,默认为checkpoint
ckpt = tf.train.get_checkpoint_state(checkpoint_dir,\
 latest_filename = None)

**(2)**如果ckpt存在,且保存的模型在指定路径中存在

	if ckpt and ckpt.model_checkpoint_path: 

恢复当前会话,将ckpt中的值赋给 w 和 b
参数说明:sess:表示当前会话,之前保存的结果会被加载入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,
因为有了位置会自动去查看checkpoint文件,看最新的模型叫什么

【注】内容来自mooc人工智能实践第六讲

猜你喜欢

转载自blog.csdn.net/petSym/article/details/84062615