tf.nn.moments()函数理解

官方的输入定义如下:

def moments(x, axes, name=None, keep_dims=False)

解释如下:

    x 可以理解为我们输出的数据,形如 [batchsize, height, width, kernels]
    axes 表示在哪个维度上求解,是个list,例如 [0, 1, 2]
    name 就是个名字,不多解释
    keep_dims 是否保持维度,不多解释

这个函数的输出有两个,用官方的话说就是:

    Two Tensor objects: mean andvariance.

解释如下:

    mean 就是均值啦
    variance 就是方差啦

个人试验后总结理解就是将x上除去axes所指定的纬度的剩余纬度组成的各个子元素看做个体,个体中的每个位置的值看做个体的不同位置属性,然后求所有个体在每种位置属性上的均值和方差

例子如下

eg1:

代码:

# coding: utf-8
import tensorflow as tf
img = tf.Variable(tf.random_normal([128, 4, 2, 3]))
axis = [0,1,2]#所以剩余的是四纬看做一个整体shape为[3]
mean, variance = tf.nn.moments(img, axis)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    mean_, variance_ = sess.run([mean, variance])
    print("均值:",mean_)
    print("方差:",variance_)

结果:

均值: [-0.03587267  0.06021447  0.02401767]
方差: [0.99473494 0.93040663 0.98113006]

eg2:

代码:

# coding: utf-8
import tensorflow as tf
img = tf.Variable(tf.random_normal([128, 4, 2, 3]))
axis = [0,1]#所以剩余的是第三 四纬看做一个整体shape为[2, 3]
mean, variance = tf.nn.moments(img, axis)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    mean_, variance_ = sess.run([mean, variance])
    print("均值:",mean_)
    print("方差:",variance_)

结果:

均值: [[-0.04313184  0.01417894  0.06847101]
 [ 0.04183875 -0.01508999 -0.11406976]]
方差: [[0.976376   0.91841435 1.0207324 ]
 [1.0403597  0.9773739  1.0360421 ]]

猜你喜欢

转载自blog.csdn.net/qq_16320025/article/details/89397883
今日推荐