tensorflow模型的保存和载入
最近在学习tensorflow,需要用到模型的保存和载入。保存比较简单,但是在载入后调用的时候,觉得有点麻烦。不是很清楚。多方了解,写出了下面的代码,希望能满足各位的需要
# -*-coding:utf-8 -*-
import tensorflow as tf
'''
图的持久化
对需要操作的部分进行命名
然后,在载入图之后,run就可以了
'''
def interface(x, y):
with tf.variable_scope("test_1"):
a = tf.Variable(tf.constant(1.0, shape=[1], name="a"))
test1 = tf.add(x, a, name="test1")
y_ = x + y + a
return y_
def train():
g1 = tf.Graph()
with g1.as_default():
# 将图持久化
# v1 = tf.Variable(tf.constant(1.0, shape=[1], name="v1"))
# v2 = tf.Variable(tf.constant(1.0, shape=[1], name="v2"))
x = tf.placeholder(tf.float32, shape=[2], name="v1")
y = tf.placeholder(tf.float32, shape=[2], name="v2")
y_ = interface(x, y)
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
_y = sess.run([y_], feed_dict={x: [3.0, 2.0], y: [4.0, 3.0]})
print _y
saver.save(sess, "./model/model.ckpt")
def restore():
saver = tf.train.import_meta_graph("./model/model.ckpt.meta")
test1 = tf.get_default_graph().get_tensor_by_name("test_1/test1:0")
with tf.Session() as sess:
saver.restore(sess, "./model/model.ckpt")
_test1 = sess.run([test1],
feed_dict={tf.get_default_graph().get_tensor_by_name("v1:0"): [30.0],
tf.get_default_graph().get_tensor_by_name("v2:0"): [40.0]})
print _test1[0].shape
if __name__ == "__main__":
train()
restore()