TensorFlow中tf.add_to_collection,tf.get_collecton和tf.add_n的使用方法简介

tf.add_to_collection(name, value)

把 value 中的变量放入到 name 作为 key 的一个集合里,也就是把多个变量统一放在一个列表中。

# 在 loss 作为 key 的集合中添加变量
>>> tf.add_to_collection('loss', tf.Variable(tf.truncated_normal([3, 3], stddev = 0.1), name = 'var1'))
>>> tf.add_to_collention('loss', tf.Variable(tf.truncated_normal([3, 3], stddev = 0.2), name = 'var2'))

tf.get_collection(key, scope = None)

从 key 确定的集合中取出 name 为 scope 的元素,如果 scope 为 None,则取出集合中的所有元素。

# var1 为一个 tf.Variaible 组成的列表
>>> var1 =  tf.get_collection('loss', 'var1')
[<tf.Variable 'var1:0' shape=(3, 3) dtype=float32_ref>]

tf.add_n(inputs, name = None)

将 inputs 中的所有 tensor 进行相加。
inputs: A list of Tensor or IndexedSlices objects, each with same shape and type.
name:A name for the operation (optional).

# 注意 inputs 参数应该为 list 类型
>>> tf.add_n(tf.get_collection('loss'), name = 'add_n')

猜你喜欢

转载自blog.csdn.net/xhj_enen/article/details/88245662
今日推荐