BN的过程,具体是怎样计算均值和方差的?


对于一个小批次的图像样本,NCHW [128,3,10,10], BN的过程,具体是怎样计算均值和方差的?
下来找到部分相关代码如下:
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value
    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
    scale, offset = self.gamma, self.beta
    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
最后一行的 reduction_axes 去除的元素是在如下的代码: 
axis = 1 if data_format == DATA_FORMAT_NCHW else -1
效果就是在C channel这一维计算均值和方差!

先搜了一下moments这个东东。。名字是矩? moment是动量?  其实 其原始含义是“to move"或者“移动”,这样就好理解了。不了解的同学也先自行了解吧。
其实在nn.moments的注释里有提示:
  When using these moments for batch normalization (see
  `tf.nn.batch_normalization`):
   * for so-called "global normalization", used with convolutional filters with
     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
   * for simple batch normalization pass `axes=[0]` (batch only).
  
  
  
看一个例子:
import tensorflow as tf
a = []
for i in range(24):
    for j in range(5):
        a.append(i+1)
b = tf.constant(a,shape= [2,3,4,5])
axis1 = list(range(len(shape)-1))  #从最后一维计算
axis2 = list(range(len(shape)))   
del axis2[1]                                  #模仿NCHW,从chanel维计算
end_mean, end_var = tf.nn.moments(b, axis1)
cha_mean, cha_var = tf.nn.moments(b, axis2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for temp in [b,end_mean,cha_mean]:
        print ('\n',sess.run(temp))

打印结果:
 [[[[ 1  1  1  1  1]
   [ 2  2  2  2  2]
   [ 3  3  3  3  3]
   [ 4  4  4  4  4]]

  [[ 5  5  5  5  5]
   [ 6  6  6  6  6]
   [ 7  7  7  7  7]
   [ 8  8  8  8  8]]

  [[ 9  9  9  9  9]
   [10 10 10 10 10]
   [11 11 11 11 11]
   [12 12 12 12 12]]]


 [[[13 13 13 13 13]
   [14 14 14 14 14]
   [15 15 15 15 15]
   [16 16 16 16 16]]

  [[17 17 17 17 17]
   [18 18 18 18 18]
   [19 19 19 19 19]
   [20 20 20 20 20]]

  [[21 21 21 21 21]
   [22 22 22 22 22]
   [23 23 23 23 23]
   [24 24 24 24 24]]]]

 [12 12 12 12 12]

 [ 8 12 16]

猜你喜欢

转载自blog.csdn.net/anthea_luo/article/details/80860579
今日推荐