Tensorflow中tf.train.ExponentialMovingAverage()函数

先解释一下滑动平均的作用:

Some training algorithms, such as GradientDescent and Momentum often benefit from maintaining a moving average of variables during optimization. Using the moving averages for evaluations often improve results significantly.

一些训练算法(如GradientDescent和Momentum)通常可以在优化过程中保持变量的移动平均值。 使用移动平均值进行评估通常会显着改善结果。


什么是滑动平均:

   

这种求平均值的好处是可以产生一个迭代效果,不必在求t+1时刻的值得时候还要保存a1,a2,a3...at的值,只保存一个之前时刻的平均值mvt就可以了,显然是衰减系数,如果t很大的时候decay接近于1,也就是0.999


这样有什么好处:

     主要是通过控制衰减率来控制参数更新前后之间的差距,从而达到减缓参数的变化值(如,参数更新前是5,更新后的值是4,通过滑动平均模型之后,参数的值会在4到5之间),如果参数更新前后的值保持不变,通过滑动平均模型之后,参数的值仍然保持不变。


Tensorflow下的滑动平均:

TensorFlow下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率decay。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个待更新的变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

                            

在滑动平滑模型中, decay 决定了模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.999或0.9999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小

                                            

看看后面那个参数有没有很像t/(t+1),num_updates就是迭代的次数。

import tensorflow as tf
graph=tf.Graph()
with graph.as_default():
    w = tf.Variable(dtype=tf.float32,initial_value=1.0)
    ema = tf.train.ExponentialMovingAverage(0.9)
    update = tf.assign_add(w, 1.0)#更新w,相当于w=w+1.0,

    with tf.control_dependencies([update]):#ema_op的执行依赖于update的执行,也就是在执行ema_op之前先执行update
        ema_op = ema.apply([w])#返回一个op,这个op用来更新moving_average #这句和下面那句不能调换顺序

    ema_val = ema.average(w)#此op用来返回当前的moving_average,这个参数不能是list

with tf.Session(graph=graph) as sess:
    sess.run(tf.initialize_all_variables())
    for i in range(3):
        print i
        print 'w_old=',sess.run(w)
        print sess.run(ema_op)
        print 'w_new=', sess.run(w)
        print sess.run(ema_val)
        print '**************'

输出结果:

0
w_old= 1.0
None
w_new= 2.0#在执行ema_op时先执行了对w的更新
1.1  #0.9*1.0+0.1*2.0=1.1
**************
1
w_old= 2.0
None
w_new= 3.0
1.29  #0.9*1.1+0.1*3.0=1.29
**************
2
w_old= 3.0
None
w_new= 4.0
1.561  #0.9*1.29+0.1*4.0=1.561

再见

猜你喜欢

转载自blog.csdn.net/mieleizhi0522/article/details/80424731