How to average checkpoints in TensorFlow

First, get the checkpoint state and the list of variables stored in the checkpoints:

import os

import numpy as np
import tensorflow as tf

# All checkpoints currently tracked in model_dir.
ckpt_state = tf.train.get_checkpoint_state(model_dir)
ckpts = ckpt_state.all_model_checkpoint_paths

# Directory where the averaged checkpoint will be written.
avg_model_dir = os.path.join(model_dir, "avg_ckpts")
tf.gfile.MakeDirs(avg_model_dir)

# (name, shape) pairs for every variable in the first checkpoint.
var_list = tf.contrib.framework.list_variables(ckpts[0])
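In practice it is common to average only the most recent checkpoints rather than every checkpoint ever saved. A minimal sketch of that filtering step, assuming a hypothetical num_to_average parameter that is not part of the original snippet:

# Keep only the most recent checkpoints for averaging.
# num_to_average is a hypothetical knob, not defined in the original code.
num_to_average = 5
ckpts = list(ckpts)[-num_to_average:]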

Then sum each variable across all checkpoints and divide by the number of checkpoints:

var_values, var_dtypes = {}, {}

# Zero-initialize an accumulator for every variable.
for (name, shape) in var_list:
    var_values[name] = np.zeros(shape)

# Add up the value of each variable from every checkpoint.
for ckpt in ckpts:
    reader = tf.contrib.framework.load_checkpoint(ckpt)
    for name in var_values:
        tensor = reader.get_tensor(name)
        var_dtypes[name] = tensor.dtype
        var_values[name] += tensor

# Divide by the number of checkpoints to get the average.
for name in var_values:
    var_values[name] /= len(ckpts)

Finally, save the averaged values into a new checkpoint:

# Recreate each variable with its averaged value (cast back to the original
# dtype, since the numpy accumulators are float64) and write a new checkpoint.
tf_vars = [
    tf.get_variable(name, dtype=var_dtypes[name],
                    initializer=var_values[name].astype(var_dtypes[name]))
    for name in var_values
]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, os.path.join(avg_model_dir, "qe.ckpt"))
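As a quick sanity check (not part of the original post), you can read the averaged checkpoint back with the same reader API used above and compare it against the in-memory averages:

# Re-read the saved checkpoint and verify it matches var_values.
avg_ckpt = os.path.join(avg_model_dir, "qe.ckpt")
reader = tf.contrib.framework.load_checkpoint(avg_ckpt)
for name, _ in tf.contrib.framework.list_variables(avg_ckpt):
    averaged = reader.get_tensor(name)
    np.testing.assert_allclose(
        averaged, var_values[name].astype(averaged.dtype), rtol=1e-6)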

Reposted from blog.csdn.net/bonjourdeutsch/article/details/102662437