tf.truncated_normal()函数解析

最近在看batch normalization对的代码,碰到tf.truncated_normal(),特此记录。

tf.truncated_normal(
	shape, 
	mean=0.0, 
	stddev=1.0, 
	dtype=tf.float32, 
	seed=None, 
	name=None
)

参数:

  • shape: 一维的张量,也是输出的张量。
  • mean: 正态分布的均值。
  • stddev: 正态分布的标准差。
  • dtype: 输出的类型。
  • seed: 一个整数,当设置之后,每次生成的随机数都一样。
  • name: 操作的名字。

从截断的正态分布中输出随机值。

生成的值服从具有指定平均值和标准偏差的正态分布,如果生成的值大于平均值2个标准偏差的值则丢弃重新选择。

例子1:

import tensorflow as tf

a = tf.truncated_normal(shape=[3, 3], mean=0, stddev=1)

with tf.Session() as sess:
    print(sess.run(a))
> [[-0.47420496 -0.4512505   1.168068  ]
   [-0.6720735  -1.0473527   0.71249765]
   [-1.529904   -0.6563566   0.561471  ]]

例子2:

import tensorflow as tf

a = tf.Variable(tf.random_normal([2,2],seed=1))
b = tf.Variable(tf.truncated_normal([2,2],seed=2))
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(a))
    print(sess.run(b))
> [[-0.8113182   1.4845988 ]
   [ 0.06532937 -2.4427042 ]]
  [[-0.85811085 -0.19662298]
   [ 0.13895045 -1.2212768 ]]

指定seed之后,a的值不变,b的值也不变。

猜你喜欢

转载自blog.csdn.net/TeFuirnever/article/details/88927892