保存完整的模型
生成一下4个文件:
checkpoint
.meta
.data-00000-of-00001
.index
使用此方法保存时,会生成3个文件,其中后缀为.meta
保存图的结构和常量,.data
和.index
保存模型的权重,偏差,梯度其他的变量的所有的值等。
-
创建服务对象,如果
Saver()
有传入,表示只对传入的值有相应的后续效果saver = tf.train.Saver()
-
适应
saver
进行模型保存saver.save(sess,保存的文件名)
导出模型为一个文件
将图中所有的变量装换为常量tf.graph_util.convert_variables_to_constants()
写入到一个模型中tf.train.writer_graph()
import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 这里需要填入输出tensor的名字
graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)
保存部分模型
使用tf.global_varibales()
获取所有变量,通过变量所有变量保存满足条件的val.name
的变量, 使用tf.train.Saver()
保存
var=tf.globale_variables()
var_to_restore=[val for val in var if 'conv1' in val.name or 'conv2' in val.name]
saver=tf.train.Saver(var_to_restore)
saver.restore(sess, model_dir)
# 初始化要保存的变量
模型加载
模型加载使用, 不能在适用tf.global_variable_initialize
进行初始化, 否则导致无效.
单个文件模型的加载
def create_graph():
# 读取
with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
官方无源代码模型加载
# 加载图 check_dir + '.meta'
saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file))
# 加载权重
saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
节点定位
使用tf.summary.Filewriter()
进行可视化, 在tensorboard中找到需要的模型的节点
# 获取默认图对象
graph=tf.get_default_graph()
# 得到图中所有的节点
all_op = graph.get_operations()
# 获取需要的节点
layers = [op.name for op in all_op if 'stack' in op.name and op.type == 'Pack']
# 得到节点输出的张量, 其中":0" 表示该节点的第1个输出
op_out_tensor = graph.get_tensor_by_name(layer[0]+':0')