For a mini-batch of image samples in NCHW layout, e.g. [128, 3, 10, 10], how exactly does batch normalization (BN) compute the mean and variance?
Digging in, I found the relevant code below (from TensorFlow's batch-normalization layer):
def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value
    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
    scale, offset = self.gamma, self.beta
    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        mean, variance = nn.moments(inputs, reduction_axes)
The axis deleted from the `reduction_axes` that the last line passes to `nn.moments` is determined by the following code:
axis = 1 if data_format == DATA_FORMAT_NCHW else -1
For NCHW, axis = 1, so the channel axis is the one deleted from `reduction_axes`. The effect: the mean and variance are reduced over the N, H, W dimensions, yielding one mean and one variance per channel C!
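To make this concrete, here is what those two steps produce for the NCHW input [128, 3, 10, 10] from the question (a minimal sketch; the shape list here is just for illustration):

input_shape = [128, 3, 10, 10]                  # N, C, H, W
axis = 1                                        # NCHW: channel axis
reduction_axes = list(range(len(input_shape)))  # [0, 1, 2, 3]
del reduction_axes[axis]                        # [0, 2, 3]: reduce over N, H, W
# nn.moments(inputs, [0, 2, 3]) then yields a mean and a variance of shape [3].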
First I looked up this `moments` thing... Is the name "moment" as in statistics? Or momentum, as in physics? The word originally meant "to move", which makes it easier to remember, but here it is the statistical sense: the mean is the first moment of a distribution and the variance is its second central moment, which is exactly the pair this function returns. If the term is new to you, it's worth a quick look on your own.
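As a quick sanity check on that reading, a minimal sketch (TF1-style API; the values are arbitrary):

import tensorflow as tf

x = tf.constant([1., 2., 3., 4.])
mean, var = tf.nn.moments(x, axes=[0])
# mean -> 2.5 (the first moment), var -> 1.25 (the second central moment)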
In fact, the docstring of `nn.moments` gives a hint:
When using these moments for batch normalization (see
`tf.nn.batch_normalization`):
* for so-called "global normalization", used with convolutional filters with
shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
* for simple batch normalization pass `axes=[0]` (batch only).
Let's look at an example:
import tensorflow as tf

a = []
for i in range(24):
    for j in range(5):
        a.append(i + 1)
shape = [2, 3, 4, 5]
b = tf.constant(a, shape=shape)
axis1 = list(range(len(shape) - 1))  # [0, 1, 2]: reduce over every axis except the last (NHWC-style)
axis2 = list(range(len(shape)))
del axis2[1]                         # [0, 2, 3]: mimic NCHW and keep only the channel axis
end_mean, end_var = tf.nn.moments(b, axis1)
cha_mean, cha_var = tf.nn.moments(b, axis2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for temp in [b, end_mean, cha_mean]:
        print('\n', sess.run(temp))
The printed output:
[[[[ 1  1  1  1  1]
   [ 2  2  2  2  2]
   [ 3  3  3  3  3]
   [ 4  4  4  4  4]]

  [[ 5  5  5  5  5]
   [ 6  6  6  6  6]
   [ 7  7  7  7  7]
   [ 8  8  8  8  8]]

  [[ 9  9  9  9  9]
   [10 10 10 10 10]
   [11 11 11 11 11]
   [12 12 12 12 12]]]

 [[[13 13 13 13 13]
   [14 14 14 14 14]
   [15 15 15 15 15]
   [16 16 16 16 16]]

  [[17 17 17 17 17]
   [18 18 18 18 18]
   [19 19 19 19 19]
   [20 20 20 20 20]]

  [[21 21 21 21 21]
   [22 22 22 22 22]
   [23 23 23 23 23]
   [24 24 24 24 24]]]]

[12 12 12 12 12]

[ 8 12 16]
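Checking by hand: `end_mean` reduces over axes [0, 1, 2], so each of the 5 positions along the last axis averages the values 1..24, i.e. 300 / 24 = 12.5, which prints as 12 because `b` is an int32 tensor and the mean is truncated. `cha_mean` reduces over axes [0, 2, 3]: channel 0 averages {1..4, 13..16} = 68 / 8 = 8.5 -> 8, channel 1 averages {5..8, 17..20} = 100 / 8 = 12.5 -> 12, and channel 2 averages {9..12, 21..24} = 132 / 8 = 16.5 -> 16, hence [ 8 12 16].

So, back to the original question: for an NCHW mini-batch of shape [128, 3, 10, 10], BN calls `nn.moments` with axes [0, 2, 3] and gets a mean and a variance of shape [3], one pair per channel. A minimal sketch (float input this time, since real BN statistics are floats):

x = tf.random_normal([128, 3, 10, 10])            # NCHW mini-batch
mean, variance = tf.nn.moments(x, axes=[0, 2, 3])
# mean.shape == variance.shape == [3]: one statistic per channel,
# each computed over 128 * 10 * 10 = 12800 values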