tf.train.Saver()的定义与使用
Saver对象:用于在tf中保存、恢复Session中的变量(模型参数)
定义
model_path="/tmp/model.ckpt"
saver=tf.train.Saver()
Saver保存操作:saver.save(sess,model_path)
save_path=saver.save(sess,model_path)
Saver恢复操作:saver.restore(sess,save_path)
saver.restore(sess,model_path)
注意事项:
1.tf.train.Saver()定义在Session之前,并且必须在需要保存的变量(tf.Variable)定义之后
2.saver.save()和saver.restore()都在Session里进行
tf.train.Saver()使用代码示例
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 22:59:41 2017
@author: ZMJ
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Single-argument print call: emits the same text on Python 2 and Python 3,
# whereas the bare `print "..."` statement is a SyntaxError on Python 3.
print("Package Loaded")
# Seed NumPy's global random generator so the random draws made later in this
# script are reproducible from run to run.
np.random.seed(1)
def f(x, weight, bias):
    """Evaluate the linear model ``x * weight + bias``.

    Args:
        x: Input value(s); anything supporting ``*`` and ``+`` with the
            other operands (a number or a NumPy array, for example).
        weight: Slope applied to ``x``.
        bias: Offset added after scaling.

    Returns:
        The result of ``x * weight + bias``.
    """
    return x * weight + bias
# Ground-truth line used to synthesise the data set.
Wref, Bref = 0.7, -0.1

# Number of samples and noise scale. NOTE: despite its name, noise_var
# multiplies standard-normal samples below, so it acts as a standard deviation.
n = 20
noise_var = 0.05

# Inputs are drawn first, then the noise, so the order of the random draws
# (and therefore the generated data for a fixed seed) is unchanged.
train_X = np.random.random(size=(n, 1))
ref_Y = f(train_X, Wref, Bref)
_noise = np.random.randn(n, 1)
train_Y = ref_Y + noise_var * _noise

# Checkpoint location shared by the save and restore steps.
model_path = "/tmp/linear_model.ckpt"

# Training hyper-parameters.
lr = 0.01
epochs = 5000
display_step = 250

# Total element count of train_X; equal to n because train_X has one column.
n_samples = train_X.size
# Left panel of a 1x2 figure: noise-free reference points (red) and the noisy
# training targets (blue), both plotted against the first column of train_X.
plt.subplot(1, 2, 1)
plt.axis("equal")
for _series, _fmt, _label in (
    (ref_Y, "ro", "Original Data"),
    (train_Y, "bo", "Training Data"),
):
    plt.plot(train_X[:, 0], _series[:, 0], _fmt, label=_label)
# NOTE(review): "Sactter" looks like a typo for "Scatter"; the runtime string
# is kept as-is here.
plt.title("Sactter Plot of Data")
plt.legend(loc="lower right")
# Trainable scalars, each initialised from one standard-normal draw
# (weight is drawn before bias, matching the original draw order).
weight = tf.Variable(initial_value=np.random.randn(), name="weight")
bias = tf.Variable(initial_value=np.random.randn(), name="bias")

# Full-batch placeholders: every feed must have exactly n_samples rows and
# one column.
_batch_shape = [n_samples, 1]
x = tf.placeholder(dtype=tf.float32, shape=_batch_shape, name="input")
y = tf.placeholder(dtype=tf.float32, shape=_batch_shape, name="output")
"""
Model
"""
# Linear prediction and mean-squared-error cost over the fed batch.
pred = x * weight + bias
_squared_error = tf.pow(pred - y, 2)
cost = tf.reduce_mean(_squared_error)

# `optimizer` is the training op: running it performs one gradient-descent
# update of the variables with learning rate `lr`.
_sgd = tf.train.GradientDescentOptimizer(lr)
optimizer = _sgd.minimize(cost)

init = tf.global_variables_initializer()

# Saver definition: created after the variables above exist and before any
# session is opened; it is reused by both sessions for save and restore.
saver = tf.train.Saver()
"""
Run Model in First Session
"""
# First session: initialise the variables, run the first 500 training steps,
# then write a checkpoint to model_path.
with tf.Session() as sess:
    # A fresh session starts with uninitialised variables.
    sess.run(init)
    # The same full batch is fed on every step, so build the feed dict once.
    _train_feed = {x: train_X, y: train_Y}
    for epoch in range(500):
        # One gradient-descent update; the training op's result is not needed.
        sess.run(optimizer, feed_dict=_train_feed)
        if epoch % display_step == 0:
            c = sess.run(cost, feed_dict=_train_feed)
            # Single-argument print call: same output on Python 2 and 3.
            print("Epoch %s .Cost=%s" % (epoch, c))
    print("First Session Compelted!")
    # Saving must happen while the session is still open.
    save_path = saver.save(sess, model_path)
    print("Save Completed,Save Path = %s" % save_path)
"""
Run Model in Second Session
"""
# Second session: restore the checkpoint, run the remaining epochs - 500
# training steps, save again and read back the trained parameters.
with tf.Session() as sess:
    # No sess.run(init) here: restore() loads the checkpointed values into the
    # variables, so they do not need to be initialised first.
    saver.restore(sess, model_path)
    print("Model Restored From %s" % model_path)
    # The same full batch is fed on every step, so build the feed dict once.
    _train_feed = {x: train_X, y: train_Y}
    # The epoch counter restarts at 0 in this session.
    for epoch in range(epochs - 500):
        # One gradient-descent update; the training op's result is not needed.
        sess.run(optimizer, feed_dict=_train_feed)
        if epoch % display_step == 0:
            c = sess.run(cost, feed_dict=_train_feed)
            # Single-argument print call: same output on Python 2 and 3.
            print("Epoch %s .Cost=%s" % (epoch, c))
    print("Second Session Compelted!")
    save_path = saver.save(sess, model_path)
    print("Save Completed,Save Path = %s" % save_path)
    # Read both trained parameters while the session is still open.
    Wop, Bop = sess.run([weight, bias])
# Evaluate the fitted line at the training inputs using the learned parameters.
fop = f(train_X, Wop, Bop)

# Right panel of the 1x2 figure: reference points, noisy training targets and
# the fitted line. The original argument-less plt.plot() call drew nothing and
# has been dropped.
plt.subplot(122)
plt.plot(train_X[:, 0], ref_Y[:, 0], "ro", label="Original Data")
plt.plot(train_X[:, 0], train_Y[:, 0], "bo", label="Training Data")
plt.plot(train_X[:, 0], fop[:, 0], "k-", label="Predicted Line")
plt.title("Predicted Line")
plt.legend(loc="lower right")
plt.show()
打印的日志:
Epoch 0 .Cost=0.269742
Epoch 250 .Cost=0.0531464
First Session Compelted!
Save Completed,Save Path = /tmp/linear_model.ckpt
Model Restored From /tmp/linear_model.ckpt
Epoch 0 .Cost=0.0323754
Epoch 250 .Cost=0.019944
Epoch 500 .Cost=0.0125031
Epoch 750 .Cost=0.00804937
Epoch 1000 .Cost=0.00538358
Epoch 1250 .Cost=0.00378797
Epoch 1500 .Cost=0.00283292
Epoch 1750 .Cost=0.00226127
Epoch 2000 .Cost=0.00191911
Epoch 2250 .Cost=0.00171431
Epoch 2500 .Cost=0.00159173
Epoch 2750 .Cost=0.00151836
Epoch 3000 .Cost=0.00147444
Epoch 3250 .Cost=0.00144815
Epoch 3500 .Cost=0.00143242
Epoch 3750 .Cost=0.001423
Epoch 4000 .Cost=0.00141736
Epoch 4250 .Cost=0.00141399
Second Session Compelted!
Save Completed,Save Path = /tmp/linear_model.ckpt