[Deep Learning] Saving and Restoring TensorFlow Models

Defining and using tf.train.Saver()

The Saver object is used in TensorFlow to save and restore the variables of a Session.
Definition:

model_path="/tmp/model.ckpt"
saver=tf.train.Saver()

Saving with a Saver: saver.save(sess, model_path)

save_path=saver.save(sess,model_path)
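
saver.save() returns the path it actually wrote to (the save_path variable above). A common variation, shown here only as a hedged sketch on top of the full example further below (sess, optimizer, x, y, train_X and train_Y are assumed from that example), is to pass global_step so that periodic checkpoints get distinct file names:

# Sketch: periodic checkpointing; passing global_step makes the Saver
# append the step number to the file name, e.g. /tmp/model.ckpt-1000
for step in range(5000):
  sess.run(optimizer,feed_dict={x:train_X,y:train_Y})
  if step%1000==0:
    ckpt_path=saver.save(sess,model_path,global_step=step)
    print "Checkpoint written to %s"%ckpt_path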

Restoring with a Saver: saver.restore(sess, save_path)

saver.restore(sess,model_path)
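
When checkpoints have been written with global_step, the newest one can be located with tf.train.latest_checkpoint before restoring. A minimal sketch, assuming sess and saver already exist and that checkpoints were saved under /tmp as above:

ckpt=tf.train.latest_checkpoint("/tmp")  # e.g. "/tmp/model.ckpt-4000", or None if nothing was saved
if ckpt is not None:
  saver.restore(sess,ckpt)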

Notes:
1. tf.train.Saver() is constructed outside the Session, after the model variables have been defined (by default it saves all variables that exist at construction time).
2. saver.save() and saver.restore() are both called inside the Session.
Both points are demonstrated by the full example below.

Code example using tf.train.Saver()

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 22:59:41 2017
@author: ZMJ
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
print "Package Loaded"

np.random.seed(1)
# Ground-truth linear model: y = weight*x + bias
def f(x,weight,bias):
  return x*weight+bias

Wref=0.7                                      # reference weight
Bref=-0.1                                     # reference bias
n=20                                          # number of training samples
noise_var=0.05                                # noise scale
train_X=np.random.random((n,1))               # inputs in [0,1)
ref_Y=f(train_X,Wref,Bref)                    # noise-free targets
train_Y=ref_Y+noise_var*np.random.randn(n,1)  # noisy training targets
model_path="/tmp/linear_model.ckpt"

lr=0.01
epochs=5000
display_step=250
n_samples=train_X.size

plt.subplot(121)
plt.axis("equal")
plt.plot(train_X[:,0],ref_Y[:,0],"ro",label="Original Data")
plt.plot(train_X[:,0],train_Y[:,0],"bo",label="Training Data")
plt.title("Sactter Plot of Data")
plt.legend(loc="lower right")

weight=tf.Variable(np.random.randn(),name="weight")
bias=tf.Variable(np.random.randn(),name="bias")
x=tf.placeholder(tf.float32,shape=[n_samples,1],name="input")
y=tf.placeholder(tf.float32,shape=[n_samples,1],name="output")

"""
Model
"""
pred=x*weight+bias
cost=tf.reduce_mean(tf.pow(pred-y,2))
optimizer=tf.train.GradientDescentOptimizer(lr).minimize(cost)
init=tf.global_variables_initializer()

"""
Saver Definition
"""
saver=tf.train.Saver()

"""
Run Model in First Session
"""
with tf.Session() as sess:
  sess.run(init)
  for epoch in range(500):
    sess.run(optimizer,feed_dict={x:train_X,y:train_Y})
    if epoch%display_step==0:
      c=sess.run(cost,feed_dict={x:train_X,y:train_Y})
      print "Epoch %s .Cost=%s"%(epoch,c)
  print "First Session Completed!"

  save_path=saver.save(sess,model_path)
  print "Save Completed,Save Path = %s"%save_path

"""
Run Model in Second Session
"""
with tf.Session() as sess:
  # No sess.run(init) here: the variable values are loaded from the checkpoint instead
  saver.restore(sess,model_path)
  print "Model Restored From %s"%model_path

  for epoch in range(epochs-500):
    sess.run(optimizer,feed_dict={x:train_X,y:train_Y})
    if epoch%display_step==0:
      c=sess.run(cost,feed_dict={x:train_X,y:train_Y})
      print "Epoch %s .Cost=%s"%(epoch,c)

  print "Second Session Completed!"
  save_path=saver.save(sess,model_path)
  print "Save Completed,Save Path = %s"%save_path

  Wop=sess.run(weight)
  Bop=sess.run(bias)
  fop=f(train_X,Wop,Bop)
  plt.subplot(122)
  plt.plot(train_X[:,0],ref_Y[:,0],"ro",label="Original Data")
  plt.plot(train_X[:,0],train_Y[:,0],"bo",label="Training Data")
  plt.plot(train_X[:,0],fop[:,0],"k-",label="Predicted Line")
  plt.title("Predicted Line")
  plt.legend(loc="lower right")
  plt.show()

Printed log:

Epoch 0 .Cost=0.269742
Epoch 250 .Cost=0.0531464
First Session Completed!
Save Completed,Save Path = /tmp/linear_model.ckpt
Model Restored From /tmp/linear_model.ckpt
Epoch 0 .Cost=0.0323754
Epoch 250 .Cost=0.019944
Epoch 500 .Cost=0.0125031
Epoch 750 .Cost=0.00804937
Epoch 1000 .Cost=0.00538358
Epoch 1250 .Cost=0.00378797
Epoch 1500 .Cost=0.00283292
Epoch 1750 .Cost=0.00226127
Epoch 2000 .Cost=0.00191911
Epoch 2250 .Cost=0.00171431
Epoch 2500 .Cost=0.00159173
Epoch 2750 .Cost=0.00151836
Epoch 3000 .Cost=0.00147444
Epoch 3250 .Cost=0.00144815
Epoch 3500 .Cost=0.00143242
Epoch 3750 .Cost=0.001423
Epoch 4000 .Cost=0.00141736
Epoch 4250 .Cost=0.00141399
Second Session Completed!
Save Completed,Save Path = /tmp/linear_model.ckpt

[Figure: left panel, scatter plot of the original and training data; right panel, the training data with the fitted line after training.]
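
To double-check what was written to disk, the checkpoint can also be inspected without rebuilding the graph. A small sketch using tf.train.NewCheckpointReader against the checkpoint produced by the example above:

import tensorflow as tf

# Open the checkpoint written by the example (/tmp/linear_model.ckpt)
reader=tf.train.NewCheckpointReader("/tmp/linear_model.ckpt")
# Dictionary mapping variable names to shapes, e.g. {"weight": [], "bias": []}
print reader.get_variable_to_shape_map()
# Read the trained scalar values back by variable name
print reader.get_tensor("weight")
print reader.get_tensor("bias")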

Reposted from blog.csdn.net/qq_29340857/article/details/75463735