tensorflow模型保存, 单个文件, 模型读取, 部分模型保存,读取

保存完整的模型

生成一下4个文件:

checkpoint
.meta
.data-00000-of-00001
.index

使用此方法保存时,会生成3个文件,其中后缀为.meta保存图的结构和常量.data.index保存模型的权重,偏差,梯度其他的变量的所有的值等。

  • 创建服务对象,如果Saver()有传入,表示只对传入的值有相应的后续效果

    saver = tf.train.Saver()
    
  • 适应saver进行模型保存

    saver.save(sess,保存的文件名)
    

导出模型为一个文件

将图中所有的变量装换为常量tf.graph_util.convert_variables_to_constants()

写入到一个模型中tf.train.writer_graph()

import tensorflow as tf

w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)

#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

保存部分模型

使用tf.global_varibales()获取所有变量,通过变量所有变量保存满足条件的val.name的变量, 使用tf.train.Saver()保存

var=tf.globale_variables()
var_to_restore=[val for val in var if 'conv1' in val.name or 'conv2' in val.name]
saver=tf.train.Saver(var_to_restore)
saver.restore(sess, model_dir)
# 初始化要保存的变量

模型加载

模型加载使用, 不能在适用tf.global_variable_initialize进行初始化, 否则导致无效.

单个文件模型的加载

def create_graph():
  # 读取
  with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

官方无源代码模型加载

# 加载图 check_dir + '.meta'
saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file))
# 加载权重
saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))

节点定位

使用tf.summary.Filewriter()进行可视化, 在tensorboard中找到需要的模型的节点

# 获取默认图对象
graph=tf.get_default_graph()
# 得到图中所有的节点
all_op = graph.get_operations()
# 获取需要的节点
layers = [op.name for op in all_op if 'stack' in op.name and op.type == 'Pack']
# 得到节点输出的张量, 其中":0" 表示该节点的第1个输出
op_out_tensor = graph.get_tensor_by_name(layer[0]+':0')

猜你喜欢

转载自blog.csdn.net/qq_39124762/article/details/82945276