tf.train.ExponentialMovingAverage用法

tf.train.ExponentialMovingAverage:通过采用指数衰减保持变量的移动平均值

tf.train.ExponentialMovingAverage(
    decay, num_updates=None, zero_debias=False, name='ExponentialMovingAverage'
)

训练模型时,保持训练参数的移动平均值通常是有益的。 使用平均参数的评估有时会产生比最终训练值明显更好的结果。

apply()方法添加训练变量的影子副本,并添加操作在其影子副本中保持训练变量的移动平均值。 在构建训练模型时使用它。维持移动平均值的操作通常在每个训练步骤之后执行 average()average_name()方法可访问影子变量及其名称。在构建评估模型从检查点文件还原模型时,它们很有用。 他们有助于使用移动平均值代替上次训练的值进行评估。

移动平均值是使用指数衰减来计算的。 在创建ExponentialMovingAverage对象时,可以指定衰减值。

影子变量使用与训练变量相同的初始值进行初始化。 当运行ops来维持移动平均值时,每个影子变量都会使用以下公式进行更新:

shadow_variable -= (1 - decay) * (shadow_variable - variable)

从数学上讲,这等效于下面的经典公式,但是使用assign_sub 操作(公式中的“-=”)允许并发无锁更新变量:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

合理的衰减值接近1.0,通常在多个九度范围内:0.999、0.9999等。

使用方法:ExponentialMovingAverage()创建一个新的ExponentialMovingAverage对象。必须调用`apply()`方法来创建影子变量并添加操作以维持移动平均值。可选的num_updates参数允许动态调整衰减率。 通常要传递训练步骤的数量,通常保持在每个步骤中递增的变量中,在这种情况下,衰减速率在训练开始时会较低。 这使移动平均值移动得更快。 如果传递,则使用的实际衰减率是:

min(decay, (1 + num_updates) / (10 + num_updates))

示例程序:

import tensorflow as tf

v1 = tf.Variable(0, dtype=tf.float32)
step = tf.Variable(tf.constant(0))

ema = tf.train.ExponentialMovingAverage(0.99, step)  
# 创建一个新的ExponentialMovingAverage对象ema
maintain_average = ema.apply([v1])  
# 调用apply()方法来创建变量v1的影子变量,并添加操作以维持移动平均值

with tf.Session() as sess:
    init = tf.initialize_all_variables()  # 定义初始化变量操作
    sess.run(init)  # 执行初始化变量操作

    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))  
    # 初始的值都为0,average()和average_name()方法可访问变量v1的影子变量及其名称

    sess.run(tf.assign(v1, 5))  # 把v1变为5
    sess.run(maintain_average)  # 执行maintain_average
    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))  
    # decay=min(0.99, 1/10)=0.1, v1_shadow=0.1*0+0.9*5=4.5

    sess.run(tf.assign(step, 10000))  # steps=10000
    sess.run(tf.assign(v1, 10))  # v1=10
    sess.run(maintain_average)
    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))
    # decay=min(0.99,(1+10000)/(10+10000))=0.99,v1_shadow=0.99*4.5+0.01*10=4.555

    sess.run(maintain_average)
    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))
    # decay=min(0.99,(1+10000)/(10+10000))=0.99,
    # v1_shadow=0.99*4.555+0.01*10=4.609449999999999

# 输出结果:
# [0.0, 0.0, None]
# [5.0, 4.5, None]
# [10.0, 4.555, None]
# [10.0, 4.60945, None]

猜你喜欢

转载自blog.csdn.net/qq_36201400/article/details/108208042
今日推荐