16. Save and restore TensorFlow model parameters

The easiest way to save and restore a model is to use a tf.train.Saver() object, which adds save and restore ops to the graph for all variables, or for the variables in a specified list. The Saver object provides methods to run these ops and specifies the paths of the checkpoint files to read from and write to.


1. The tf.train.Saver() class

tf.train.Saver(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

1. Constructor parameters

  • var_list
    • specifies the variables that will be saved and restored. If None, defaults to the list of all saveable objects. It can be passed as a dict or a list:
    • A dict of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files.
    • A list of variables: The variables will be keyed with their op name in the checkpoint files.
  • For example:
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
  • max_to_keep
    • indicates the maximum number of recent checkpoint files to keep.
    • As new files are created, older files are deleted.
    • If None or 0, all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent checkpoint files are kept.)
    • Setting max_to_keep=1 keeps only the latest model. Alternatively, calling save() with global_step=None also keeps only the latest model, because every save then writes to the same file name.
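
The max_to_keep behaviour described above can be sketched in plain Python. This is only an illustration of the retention policy, not TensorFlow's implementation: the function name simulate_retention and the "my_model" prefix are hypothetical, and nothing is written to disk.

```python
from collections import deque


def simulate_retention(steps, max_to_keep=5):
    """Return the checkpoint names that would remain after saving at each step.

    Illustrative only: mimics the documented max_to_keep behaviour with a
    plain queue. It does not call TensorFlow or touch the file system.
    """
    kept = deque()
    for step in steps:
        kept.append("my_model-%d" % step)
        # None or 0 means "keep everything", as described above.
        if max_to_keep and len(kept) > max_to_keep:
            kept.popleft()  # the oldest checkpoint would be deleted
    return list(kept)


remaining = simulate_retention(range(50, 1001, 50), max_to_keep=3)
print(remaining)  # ['my_model-900', 'my_model-950', 'my_model-1000']
```

With max_to_keep=None the same call would return all 20 names, which matches the "all checkpoint files are kept" case.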

2. Common methods

# Returns a string, path at which the variables were saved.
save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

# The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.
restore(
    sess,
    save_path
)
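
As a minimal sketch of how these two methods fit together, the round trip below saves one variable and restores it in a fresh session without running an initializer. It assumes a TensorFlow build that ships the tensorflow.compat.v1 module (on TF 1.x, a plain "import tensorflow as tf" without disable_v2_behavior() should behave the same way). The scratch directory and the "demo" prefix are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: the v1 compatibility API is available

prefix = os.path.join(tempfile.mkdtemp(), "demo")  # hypothetical checkpoint prefix

graph = tf.Graph()
with graph.as_default():
    v = tf.get_variable("v", shape=[2], initializer=tf.zeros_initializer())
    add_one = v.assign(v + 1)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(add_one)
        # save() returns the path it wrote to, with "-<global_step>" appended.
        saved_path = saver.save(sess, prefix, global_step=1)

    with tf.Session() as sess:
        # No initializer is run here: restore() assigns the saved values.
        saver.restore(sess, saved_path)
        restored = sess.run(v)

print(saved_path)
print(restored)
```

The returned path is the value to pass to restore(); it is a prefix shared by several files rather than a single file on disk.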

2. Save and restore parameters

1. Introduction to checkpoint files

  • Variables are stored in binary files that mainly contain a mapping from variable names to tensor values.
  • When you create a Saver object, you can optionally choose names for the variables in the checkpoint file. By default, each variable is stored under the value of its tf.Variable.name property. (This is the name given to the variable in the graph; it has nothing to do with the Python variable name.)
  • With saver = tf.train.Saver(max_to_keep=3), the files written to the checkpoint directory are as follows:
    • The first file (checkpoint) records the paths of all model files in the directory
    • The second file (.data-00000-of-00001) holds the model itself (the mapping from variable names to tensor values)
    • The third file (.index) is the index
    • The fourth file (.meta) holds the structure of the computational graph
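
To make the four kinds of files concrete, the sketch below lists the file names one would expect to find in the checkpoint directory. It is plain Python and does not call TensorFlow; the helper expected_checkpoint_files is hypothetical, and it assumes the V2 checkpoint format, a single data shard, and write_meta_graph=True.

```python
def expected_checkpoint_files(prefix, steps):
    """List the file names a Saver would be expected to leave in a directory.

    Assumptions: V2 checkpoint format, one data shard, meta graph written.
    """
    files = ["checkpoint"]  # text file recording the checkpoint paths
    for step in steps:
        base = "%s-%d" % (prefix, step)
        files.append(base + ".data-00000-of-00001")  # tensor values
        files.append(base + ".index")                # index into the data file
        files.append(base + ".meta")                 # graph structure
    return files


names = expected_checkpoint_files("my_model", [900, 950, 1000])
for name in names:
    print(name)
```

With max_to_keep=3 and a save every 50 steps over 1000 steps, only the last three steps would be expected to remain, which is why the example passes 900, 950 and 1000.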



2. Save and restore variables

  • You can use a bool variable (isTrain in the code below) to switch between the two stages, training and validation: True means training and False means testing.
  • The tf.train.Saver() class supports renaming variables on restore: pass var_list as a dict whose keys are the names stored in the checkpoint, which overrides each variable's own name.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf

# Create some variables.
w = tf.get_variable("weight", shape=[2], initializer=tf.zeros_initializer())
b = tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())

inc_w = w.assign(w + 1)
dec_b = b.assign(b - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver(max_to_keep=3)

isTrain = False  # True means training, False means testing
train_steps = 1000
checkpoint_steps = 50
checkpoint_dir = 'checkpoint/save&restore/'
model_name = 'my_model'

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    if isTrain:
        # Do some work with the model.
        for step in range(train_steps):
            inc_w.op.run()
            dec_b.op.run()
            if (step + 1) % checkpoint_steps == 0:
                # Append the step number to the checkpoint name:
                saved_path = saver.save(
                    sess,
                    checkpoint_dir + model_name,
                    global_step=step + 1  # if set to None, only the latest result is kept
                )
    else:
        print('Before restore:')
        print(sess.run(w))
        print(sess.run(b))
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        # Get the latest model_file
        if ckpt and ckpt.model_checkpoint_path:
            print("Success to load %s." % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print('After restore:')
        print(sess.run(w))
        print(sess.run(b))

# Test results
Before restore:
[ 0.  0.]
[ 0.  0.  0.]
Success to load checkpoint/save&restore/my_model-1000.
After restore:
[ 1000.  1000.]
[-1000. -1000. -1000.]


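# Conclusion: restore is effectively equivalent to re-initializing all the variables

# Analysis of the conclusion
Although the official documentation says that init_op does not need to be run to initialize the variables when restoring, sess.run(init_op) is deliberately placed above the if isTrain: statement here (so it runs in both the training and the testing stage) in order to verify that restore effectively re-initializes all the variables. The test results confirm this.
# In fact, sess.run(init_op) could be moved inside the if isTrain: branch (so it runs only in the training stage)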
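
The renaming support mentioned at the start of this subsection is not exercised by the script above, so here is a small sketch of it. It assumes the tensorflow.compat.v1 API is available; the variable name "renamed_weight" and the scratch path are hypothetical. A variable saved under the key "weight" is restored into a differently named variable by passing a dict as var_list.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: the v1 compatibility API is available

prefix = os.path.join(tempfile.mkdtemp(), "rename_demo")  # hypothetical path

# Graph 1: save a variable whose checkpoint key is "weight".
g1 = tf.Graph()
with g1.as_default():
    w = tf.get_variable("weight", shape=[2], initializer=tf.ones_initializer())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, prefix)

# Graph 2: the variable now has a different name. The dict key tells the
# Saver which checkpoint entry to read for it.
g2 = tf.Graph()
with g2.as_default():
    w2 = tf.get_variable("renamed_weight", shape=[2],
                         initializer=tf.zeros_initializer())
    saver2 = tf.train.Saver({"weight": w2})
    with tf.Session() as sess:
        saver2.restore(sess, prefix)  # no initializer needed
        restored = sess.run(w2)

print(restored)
```

Without the dict, the second Saver would look for a checkpoint entry named "renamed_weight" and the restore would fail, since only "weight" was saved.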

3. Get the values of trainable parameters & extract the features of a layer

sess = tf.Session()

# Returns all variables created with trainable=True in a var_list
var_list = tf.trainable_variables()

print("Trainable variables:------------------------")

# Print the index, shape and name of every trainable parameter
for idx, v in enumerate(var_list):
    print("param {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))


# Example output for some network
Trainable variables:------------------------
  param   0: (5, 5, 3, 32)     conv2d/kernel:0
  param   1: (32,)             conv2d/bias:0
  param   2: (5, 5, 32, 64)    conv2d_1/kernel:0
  param   3: (64,)             conv2d_1/bias:0
  param   4: (3, 3, 64, 128)   conv2d_2/kernel:0
  param   5: (128,)            conv2d_2/bias:0
  param   6: (3, 3, 128, 128)   conv2d_3/kernel:0
  param   7: (128,)            conv2d_3/bias:0
  param   8: (4608, 1024)      dense/kernel:0
  param   9: (1024,)           dense/bias:0
  param  10: (1024, 512)       dense_1/kernel:0   ---> parameters of the dense2 layer
  param  11: (512,)            dense_1/bias:0
  param  12: (512, 5)          dense_2/kernel:0
  param  13: (5,)              dense_2/bias:0


# Extract the parameters W and b of the last fully connected layer
W = sess.run(var_list[12])
b = sess.run(var_list[13])

# Use the output of the second fully connected layer as the feature
feature = sess.run(dense2, feed_dict={x:img})
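
The snippet above relies on a network, a placeholder x and an input img that are defined elsewhere. The following self-contained sketch shows the same two steps on a tiny two-layer network. The layer sizes, the scope names dense_1 and dense_2, and the all-ones input are assumptions made for illustration, and the code assumes the tensorflow.compat.v1 API is available.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: the v1 compatibility API is available

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4], name="x")

    with tf.variable_scope("dense_1"):
        w1 = tf.get_variable("kernel", shape=[4, 3])
        b1 = tf.get_variable("bias", shape=[3],
                             initializer=tf.zeros_initializer())
        dense1 = tf.nn.relu(tf.matmul(x, w1) + b1)  # layer used as the feature

    with tf.variable_scope("dense_2"):
        w2 = tf.get_variable("kernel", shape=[3, 2])
        b2 = tf.get_variable("bias", shape=[2],
                             initializer=tf.zeros_initializer())
        logits = tf.matmul(dense1, w2) + b2

    var_list = tf.trainable_variables()
    names = [v.name for v in var_list]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        img = np.ones((2, 4), dtype=np.float32)  # hypothetical input batch

        # Parameters of the last layer, taken from the end of var_list.
        W = sess.run(var_list[-2])
        b = sess.run(var_list[-1])

        # Output of the first dense layer, used as the extracted feature.
        feature = sess.run(dense1, feed_dict={x: img})

print(names)
print(W.shape, b.shape, feature.shape)
```

Indexing from the end of var_list avoids hard-coding positions such as 12 and 13, which change whenever layers are added or removed.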

3. Continue training & Fine-tune a certain layer

1. Continue training (all parameters)

# Define a global object for reading flag values; refer to a flag in the program as, e.g., FLAGS.iteration
FLAGS = tf.app.flags.FLAGS


# Define command-line flags: the first argument is the flag name, the second its default value, the third its description
tf.app.flags.DEFINE_string(
    "checkpoint_dir", 
    "/path/to/checkpoint_save_dir/", 
    "Directory name to save the checkpoints [checkpoint]"
)
tf.app.flags.DEFINE_boolean(
    "continue_train", 
    False, 
    "True for continue training.[False]"
)

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if FLAGS.continue_train:
        # Automatically get the latest model_file
        model_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        saver.restore(sess, model_file)
        print("Success to load %s." % model_file)
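
One detail the snippet above does not handle is the first run, when the directory holds no checkpoint yet: tf.train.latest_checkpoint() then returns None. The sketch below guards against that case. The helper restore_if_available, the "step" counter variable and the scratch directory are hypothetical, and the code assumes the tensorflow.compat.v1 API is available.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: the v1 compatibility API is available


def restore_if_available(sess, saver, checkpoint_dir):
    """Restore the newest checkpoint if one exists; return its path or None."""
    model_file = tf.train.latest_checkpoint(checkpoint_dir)
    if model_file is None:
        return None  # nothing saved yet, keep the freshly initialized values
    saver.restore(sess, model_file)
    return model_file


ckpt_dir = tempfile.mkdtemp()  # hypothetical, initially empty directory

graph = tf.Graph()
with graph.as_default():
    step = tf.get_variable("step", shape=[], dtype=tf.int32,
                           initializer=tf.zeros_initializer(), trainable=False)
    increment = tf.assign_add(step, 1)
    saver = tf.train.Saver()

    # First run: no checkpoint exists, so training starts from scratch.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        first = restore_if_available(sess, saver, ckpt_dir)
        sess.run(increment)
        saver.save(sess, os.path.join(ckpt_dir, "model"), global_step=1)

    # Second run: the checkpoint is found and training resumes from it.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        second = restore_if_available(sess, saver, ckpt_dir)
        resumed_step = sess.run(step)

print(first, second, resumed_step)
```

Saving a step counter alongside the weights is one way to let a resumed run continue its step numbering instead of starting again from zero.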

2. Fine-tune a layer

  • Modify the weight and bias variables in the network that should stay fixed, setting their trainable parameter to False
  • Then use the code above to continue training,
    eg: my_non_trainable = tf.get_variable("my_non_trainable", shape=(3, 3), trainable=False)
  • Restore a meta checkpoint (to be completed)
    • use the TF helper tf.train.import_meta_graph()
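
The two ideas in this list can be tried out together in a short sketch. It is an illustration under stated assumptions, not the author's method: the variable names "frozen" and "tuned", the toy loss, the learning rate and the scratch path are all hypothetical, and the code assumes the tensorflow.compat.v1 API is available. A variable created with trainable=False is left untouched by the optimizer, and the saved graph is then rebuilt in a fresh graph with tf.train.import_meta_graph().

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: the v1 compatibility API is available

prefix = os.path.join(tempfile.mkdtemp(), "finetune_demo")  # hypothetical path

g1 = tf.Graph()
with g1.as_default():
    frozen = tf.get_variable("frozen", shape=[2],
                             initializer=tf.ones_initializer(), trainable=False)
    tuned = tf.get_variable("tuned", shape=[2],
                            initializer=tf.zeros_initializer())
    loss = tf.reduce_sum(tf.square(frozen + tuned - 3.0))  # toy loss
    # By default minimize() only updates the variables in tf.trainable_variables().
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    saver = tf.train.Saver()  # saves both variables, trainable or not

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        frozen_after, tuned_after = sess.run([frozen, tuned])
        saver.save(sess, prefix)

# Rebuild the graph from the .meta file instead of re-running the model code.
g2 = tf.Graph()
with g2.as_default():
    meta_saver = tf.train.import_meta_graph(prefix + ".meta")
    restored_vars = {v.op.name: v for v in tf.global_variables()}
    with tf.Session() as sess:
        meta_saver.restore(sess, prefix)
        restored_tuned = sess.run(restored_vars["tuned"])

print(frozen_after, tuned_after, restored_tuned)
```

After one gradient step the frozen variable should still hold its initial ones, while the tuned variable has moved toward the target, and the imported graph restores that updated value.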

4. References

1. https://www.tensorflow.org/api_docs/python/tf/train/Saver
2. tensorflow learning: model saving and restoration (Saver)
3. Tensorflow series - the usage of Saver
4. tensorflow 1.0 learning: parameters and feature extraction
5. https://www.tensorflow.org/api_guides/python/meta_graph
6. https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125
