tf.nn.sufficient_statistics(x, axes, shift=None,keep_dims=False, name=None)

Reference: https://www.cnblogs.com/hellcat/p/6906065.html



1. Purpose: computes the sufficient statistics for the mean and variance of x.

2. Returns: a 4-tuple -> (element count, (possibly shifted) sum of the elements, (possibly shifted) sum of squares of the elements, shift)
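In formula form, writing $s$ for shift (taken as 0 when shift=None) and $n$ for the count, the first three returned values and the mean and (population) variance they determine are as follows. This is a summary of the shifted-data algorithm linked in the docstring quoted in section 4, not something the post itself derives.

```latex
n = \prod_{d \in \text{axes}} \dim_d(x), \qquad
S_1 = \sum_{\text{axes}} (x - s), \qquad
S_2 = \sum_{\text{axes}} (x - s)^2

\bar{x} = s + \frac{S_1}{n}, \qquad
\sigma^2 = \frac{S_2}{n} - \left(\frac{S_1}{n}\right)^2
```

For the first column of the example below ($s = 0$): $n = 2$, $S_1 = 5$, $S_2 = 17$, so $\bar{x} = 2.5$ and $\sigma^2 = 8.5 - 6.25 = 2.25$.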

3. Example

# tf.__version__  ->  1.4.0
import tensorflow as tf

size = 3
W = tf.constant([[1., 2., 3.], [4., 5., 6.]])
shift = tf.Variable(tf.zeros([size]))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ss = sess.run(tf.nn.sufficient_statistics(W, axes=[0], shift=shift))  # axes=[0]: compute per column
    print(ss)
    for i in ss:
        print(i)

Output:

(2.0, array([5., 7., 9.], dtype=float32), array([17., 29., 45.], dtype=float32), array([0., 0., 0.], dtype=float32))
2.0
[5. 7. 9.]
[17. 29. 45.]
[0. 0. 0.]
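As a quick sanity check (an illustration using NumPy, not part of the original post), the same numbers can be reproduced by hand: with a zero shift, axes=[0] sums each column, and the count is just the number of rows.

```python
import numpy as np

# Same data and zero shift as the TensorFlow example above
W = np.array([[1., 2., 3.], [4., 5., 6.]])
shift = np.zeros(3)

count = float(W.shape[0])                  # axes=[0]: reduce over the rows
mean_ss = np.sum(W - shift, axis=0)        # column sums of the shifted data
var_ss = np.sum((W - shift) ** 2, axis=0)  # column sums of squares of the shifted data

print(count)    # 2.0
print(mean_ss)  # [5. 7. 9.]
print(var_ss)   # [17. 29. 45.]
```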

4. Source code analysis:

def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keep_dims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  axes = list(set(axes))
  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    x_shape = x.get_shape()
    if all(x_shape[d].value is not None for d in axes):
      counts = 1
      for d in axes:
        counts *= x_shape[d].value
      counts = constant_op.constant(counts, dtype=x.dtype)
    else:  # shape needs to be inferred at runtime.
      x_dims = array_ops.gather(
          math_ops.cast(array_ops.shape(x), x.dtype), axes)
      counts = math_ops.reduce_prod(x_dims, name="count")
    if shift is not None:
      shift = ops.convert_to_tensor(shift, name="shift")
      m_ss = math_ops.subtract(x, shift)
      v_ss = math_ops.squared_difference(x, shift)
    else:  # no shift.
      m_ss = x
      v_ss = math_ops.square(x)
    m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
    v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
  return counts, m_ss, v_ss, shift
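To make the control flow easier to follow, here is a rough NumPy transcription of the function above. It is a sketch for illustration only: the name sufficient_statistics_np is hypothetical, NumPy always knows the shape, so the runtime reduce_prod branch collapses into a plain product, and NumPy's keepdims stands in for TensorFlow's keep_dims.

```python
import numpy as np

def sufficient_statistics_np(x, axes, shift=None, keep_dims=False):
    """Hypothetical NumPy mirror of tf.nn.sufficient_statistics (illustration only)."""
    axes = tuple(sorted(set(axes)))       # deduplicate axes, like list(set(axes))
    x = np.asarray(x, dtype=np.float64)
    # Element count: product of the sizes of the reduced dimensions
    counts = float(np.prod([x.shape[d] for d in axes]))
    if shift is not None:
        shift = np.asarray(shift, dtype=x.dtype)
        m_ss = x - shift                  # like math_ops.subtract(x, shift)
        v_ss = (x - shift) ** 2           # like math_ops.squared_difference(x, shift)
    else:
        m_ss = x
        v_ss = x ** 2                     # like math_ops.square(x)
    m_ss = np.sum(m_ss, axis=axes, keepdims=keep_dims)
    v_ss = np.sum(v_ss, axis=axes, keepdims=keep_dims)
    return counts, m_ss, v_ss, shift      # shift stays None if it was None

W = [[1., 2., 3.], [4., 5., 6.]]
# -> count 6.0, sum 21.0, sum of squares 91.0, shift None
print(sufficient_statistics_np(W, axes=[0, 1]))
```

Note that the returned shift is passed through untouched, which is exactly why it comes back as None in section 5.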

4.1 How the element count is computed:

for d in axes:
    counts *= x_shape[d].value

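So the element count is simply the product of the sizes of the dimensions listed in axes. If any of those sizes is unknown when the graph is built, the else branch computes the same product at runtime with array_ops.gather and math_ops.reduce_prod instead.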
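A small plain-Python illustration of the same product (element_count is a hypothetical helper, not a TensorFlow function). Because the source first runs axes = list(set(axes)), a repeated axis is only counted once.

```python
def element_count(shape, axes):
    """Hypothetical helper: product of shape[d] over the de-duplicated axes."""
    count = 1
    for d in set(axes):        # mirrors axes = list(set(axes)) in the TF source
        count *= shape[d]
    return count

shape = (2, 3)                       # shape of W in the examples
print(element_count(shape, [0]))     # 2  (per column: 2 rows)
print(element_count(shape, [1]))     # 3  (per row: 3 columns)
print(element_count(shape, [0, 1]))  # 6  (all elements)
print(element_count(shape, [0, 0]))  # 2  (duplicate axis counted once)
```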

4.2 What does squared_difference(x, shift) return? Here is its source:

def squared_difference(x, y, name=None):
  r"""Returns (x - y)(x - y) element-wise.
  # ... (omitted) ...

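It returns (x - shift)^2, element-wise. reduce_sum then adds these values up over axes to produce var_ss, just as mean_ss is the sum of x - shift.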
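Since the three sums are all that is needed, the mean and the (population) variance can be recovered from them. The NumPy sketch below does this by hand using the shifted-data formulas from the Wikipedia article cited in the docstring; in TensorFlow this step is what tf.nn.normalize_moments performs. The helper name moments_from_stats and the particular shift values are assumptions for illustration.

```python
import numpy as np

def moments_from_stats(count, mean_ss, var_ss, shift=None):
    """Hypothetical helper: mean and population variance from sufficient statistics."""
    shifted_mean = mean_ss / count                 # E[x - shift]
    mean = shifted_mean if shift is None else shifted_mean + shift
    variance = var_ss / count - shifted_mean ** 2  # Var(x - shift) == Var(x)
    return mean, variance

W = np.array([[1., 2., 3.], [4., 5., 6.]])
shift = np.array([2., 4., 4.])             # any shift works; one close to the mean is most stable
count = W.shape[0]
mean_ss = np.sum(W - shift, axis=0)
var_ss = np.sum((W - shift) ** 2, axis=0)  # squared_difference followed by reduce_sum

mean, variance = moments_from_stats(count, mean_ss, var_ss, shift)
print(mean)      # [2.5 3.5 4.5]
print(variance)  # [2.25 2.25 2.25]
```

The shift cancels out of the variance entirely, which is why it only affects numerical stability and not the result.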

5. A "bug" encountered when shift=None:

# tf.__version__  ->  1.4.0
import tensorflow as tf

size = 3
W = tf.constant([[1., 2., 3.], [4., 5., 6.]])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ss = sess.run(tf.nn.sufficient_statistics(W, axes=[0,1]))  # axes=[0,1]: reduce over all elements
    print(ss)

Error:

# ... (omitted) ...
TypeError: Fetch argument None has invalid type <class 'NoneType'>

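This is because, with shift=None, tf.nn.sufficient_statistics(W, axes=[0,1]) returns None as the last element of its tuple (the docstring says as much: "None if `shift` is None"), and sess.run cannot fetch None, hence the TypeError.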

Printing the result directly with print(tf.nn.sufficient_statistics(W, axes=[0,1])), instead of passing it to sess.run, works fine:

(<tf.Tensor 'sufficient_statistics/Const:0' shape=() dtype=float32>, <tf.Tensor 'sufficient_statistics/mean_ss:0' shape=() dtype=float32>, <tf.Tensor 'sufficient_statistics/var_ss:0' shape=() dtype=float32>, None)

Sure enough, the last element of the returned tuple, shift, is None!

So I used a rather clumsy workaround to print the return values:

# tf.__version__  ->  1.4.0
import tensorflow as tf

size = 3
W = tf.constant([[1., 2., 3.], [4., 5., 6.]])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    A = tf.nn.sufficient_statistics(W, axes=[0])  # axes=[0]: compute per column
    for i, Ai in enumerate(A):
        if Ai is not None:  # identity check; the 4th element is None when shift=None
            print(i, Ai.eval())
        else:
            print(i, None)

Output:

0 2.0
1 [5. 7. 9.]
2 [17. 29. 45.]
3 None
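A slightly tidier alternative to the loop above is to hand sess.run only the non-None elements and put None back in place afterwards. The helper run_skipping_none below is hypothetical (not a TensorFlow API); it accepts any run function, so with TensorFlow it would be called as run_skipping_none(sess.run, tf.nn.sufficient_statistics(W, axes=[0])). The demo uses a stand-in run function so the sketch runs without TensorFlow.

```python
def run_skipping_none(run_fn, fetches):
    """Hypothetical helper: evaluate the non-None fetches with run_fn, keep None in place."""
    fetches = list(fetches)
    present = [f for f in fetches if f is not None]
    values = iter(run_fn(present))  # sess.run on a list returns a list of values
    return tuple(next(values) if f is not None else None for f in fetches)

# Stand-in for sess.run, only for demonstrating the helper without TensorFlow
fake_run = lambda items: [item * 10 for item in items]
print(run_skipping_none(fake_run, (1, 2, 3, None)))  # (10, 20, 30, None)
```

Evaluating all tensors in a single run call also avoids calling eval() once per element.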


Reposted from blog.csdn.net/ranmw1129/article/details/80810538