以往总是用 Saver 以 checkpoint 形式保存训练结果,做预测时需要重新构建原来的网络结构并恢复参数;现在改用 pb 方式保存,方便多了,而且多次预测时不需要重置网络。代码如下:
import tensorflow as tf
import numpy as np
import os
import test1
import pickle
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.platform import gfile

tf.app.flags.DEFINE_string('checkpoints_dir3',
                           os.path.abspath('./checkpoints/travel_gan/'),
                           'checkpoints save path.')
tf.app.flags.DEFINE_string('model_prefix3', 'travel_gan', 'model save prefix.')
FLAGS = tf.app.flags.FLAGS


###############################################################################
# Training
###############################################################################
def train():
    """Train a toy two-input regression model and freeze it to ./graph.pb.

    The frozen GraphDef exposes named endpoints so a later process can run
    inference without rebuilding the network:
      inputs : 'input_0', 'input_1'  (float32, shape [None, 1])
      outputs: 'out_0', 'out_1', 'out_loss'
    """
    train_steps = 100

    # Two input placeholders; batch dimension left open (None).
    x = tf.placeholder(tf.float32, shape=[None, 1], name='input_0')
    x2 = tf.placeholder(tf.float32, shape=[None, 1], name='input_1')

    # Random training data as (10, 1) column vectors.
    x_data = np.random.rand(10).astype(np.float32).reshape(10, 1)
    x2_data = np.random.rand(10).astype(np.float32).reshape(10, 1)

    # Target function the model tries to fit: y = 4*x*x2 + 4.
    y = 4 * x * x2 + 4

    w = tf.Variable(tf.random_normal([1], -1, 1), name="w")
    b = tf.Variable(tf.zeros([1]), name="b")

    # Two named output ops (same math) so the frozen graph can be queried
    # through either endpoint name.
    y_predict = tf.add(w * x * x2, b, name='out_0')
    y_predict2 = tf.add(w * x * x2, b, name='out_1')

    loss = tf.reduce_mean(tf.square(y - y_predict), name="out_loss")
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    # Renamed from `train` to avoid shadowing this function's own name.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
        # global_variables_initializer() replaces the long-deprecated
        # initialize_all_variables().
        sess.run(tf.global_variables_initializer())
        for _ in range(train_steps):
            sess.run(train_op, feed_dict={x: x_data, x2: x2_data})
        # Freeze: convert every variable reachable from the listed output
        # nodes into constants, then serialize the pruned graph.
        graph = convert_variables_to_constants(
            sess, sess.graph_def, ["out_0", "out_1", "out_loss"])
        tf.train.write_graph(graph, '.', "graph.pb", as_text=False)
###############################################################################
# Inference: load the frozen model saved by train(); no network rebuild needed.
###############################################################################
def predict():
    """Run inference by remapping the graph's inputs to constants.

    `input_map` substitutes the named placeholders at import time, so no
    feed_dict is required when running the returned output tensor.
    """
    with tf.Session() as sess:
        with open('./graph.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        a = 1.0
        b = 2.0
        # return_elements selects which output tensor(s) to fetch;
        # 'out_0:0' would work equally well here.
        output = tf.import_graph_def(
            graph_def,
            input_map={'input_0': a, 'input_1': b},
            return_elements=['out_1:0'],
            name='a1')
        print(sess.run(output))


def predict2():
    """Load the frozen graph into a fresh default graph and feed inputs."""
    # Reset first so repeated calls (e.g. after train()) don't accumulate
    # duplicate nodes in the default graph.
    tf.reset_default_graph()
    with gfile.FastGFile("./graph.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (no import-scope prefix).
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        input_x = sess.graph.get_tensor_by_name("input_0:0")
        input_x2 = sess.graph.get_tensor_by_name("input_1:0")
        output0 = sess.graph.get_tensor_by_name("out_0:0")
        output1 = sess.graph.get_tensor_by_name("out_1:0")
        # Renamed from `output1oss` (digit 1 vs letter l was easy to misread).
        output_loss = sess.graph.get_tensor_by_name("out_loss:0")
        a = np.array([1.0])
        b = np.array([2.0])
        result = sess.run(
            [output0, output_loss],
            {input_x: a.reshape(-1, 1), input_x2: b.reshape(-1, 1)})
        print(result)


def main():
    train()
    # predict() works too; the pb format lets inference be invoked
    # repeatedly without rebuilding or resetting the network.
    predict2()


if __name__ == '__main__':
    main()