1、saver.restore()
用 saver.restore()加载 模型 之前,首先要 定义 模型 的 计算图,具体操作如下:
#重新定义 计算图
def define_graph(input):
graph_define
return computed_tensor #返回需要计算的张量
with tf.Session() as sess:
#input = tf.placeholder(tf.float,shape,name="input")
computed_tensor = define_graph(input)
#加载模型
saver = tf.train.Saver()
saver.restore(sess, model.ckpt)
#计算 张量
sess.run(computed_tensor, feed_dict={
input:input_data})
2、saver.restore() + tf.train.import_meta_graph()
#利用 tf.train.import_meta_graph()载入计算图
saver = tf.train.import_meta_graph('model.ckpt.meta')
with tf.Session() as sess:
#载入模型
saver.restore(sess,'model.ckpt')
#载入 要计算的张量名
input, computed_tensor = tf.get_default_graph().get_tensor_by_name(['input', 'computed_tensor:0'])
# 计算张量
sess.run(computed_tensor, feed_dict = {
input:input_data})
3、gfile.GFile()
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
#保存模型,及要计算的 tensor
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['input', 'computed_tensor'])
with gfile.GFile('model.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
#加载模型,以及要计算的 tensor
with tf.Session() as sess:
with gfile.FastGFile('model.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
input, computed_tensor = tf.import_graph_def(graph_def,return_elements=["input", "computed_tensor:0"])
sess.run(computed_tensor, feed_dict = {
input:input_data})