TensorFlow基础学习(一):保存和载入模型

一、保存和载入模型
1、保存模型
建立一个saver,然后在session中通过saver的save即可将模型保存起来。

saver=tf.train.Saver()        #生成saver
with tf.Session() as sess:
      sess.run(tf.global_variable_initializer())                #模型初始化
      #然后将数据丢入模型进行训练
      #训练完后,使用saver.save保存
      saver.save(sess,"save_path/file_name")    #如果file_name不存在,会自动创建

2、载入模型
将模型保存好以后,载入也方便。在session中通过调用saver的restore()函数,从制定的路径找到模型文件,并覆盖到相关参数中。

saver=tf.train.Sacer()

with tf.Session() as sess:
#参数可以进行初始化,也可以不进行初始化。
     sess.run(tf.global_variables_initializer())
     saver.restore(sess,"save_path/file_name")  #会将已经保存的变量值restore到变量中

二、分析模型内容,演示模型的其他保存方法
1、模型内容
通过编写代码将模型的内容打印出来,看看保存了哪些东西

from tensorflow .python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
savedir=" "   #保存路径
print_tensors_in_checkpoint_file(savedir+" linermodek.ckpt",None,True)

2、保存模型的其他方法
tf.train.Saver函数里面可以放参数来实现更高级的功能,可以指定存储变量名字与变量的对应关系。

saver=tf.train.Saver({'weight' : W, 'bias' : b})

三、检查点(Checkpoint)
在训练中保存模型,称之为保存检查点
使用tf.train.Saver(max_to_keep=1)代码创建saver时传入的参数max_to_keep=1表示在迭代过程中只保存一个文件。
四、模型操作相关函数总结
函数 说明
tf.train.Saver() 创建存储器Saver
tf.train.Saver.save() 保存
tf.train.Saver.restore() 恢复
tf.train.Saver.last_checkpoints 列出最近未删除的checkpoint文件名
tf.train.Saver.set_last_checkpoint() 设置checkpoint文件名列表

参考书:《深度学习之TensorFlow入门、原理与进阶实战》

猜你喜欢

转载自blog.csdn.net/weixin_43152685/article/details/91047703