tf.GraphKeys和变量初始化

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u010365819/article/details/88125375

GraphKeys

tf.GraphKeys包含所有graph collection中的标准集合名,有点像Python里的build-in fuction。

首先要了解graph collection是什么。

graph collection

在官方教程——图和会话中,介绍什么是tf.Graph是这么说的:

tf.Graph包含两类相关信息:

  • 图结构。图的节点和边缘,指明了各个指令组合在一起的方式,但不规定它们的使用方式。图结构与汇编代码类似:检查图结构可以传达一些有用的信息,但它不包含源代码传达的的所有有用上下文。
  • **图集合。**TensorFlow提供了一种通用机制,以便在tf.Graph中存储元数据集合。tf.add_to_collection函数允许您将对象列表与一个键相关联(其中tf.GraphKeys定义了部分标准键),tf.get_collection则允许您查询与键关联的所有对象。TensorFlow库的许多组成部分会使用它:例如,当您创建tf.Variable时,系统会默认将其添加到表示“全局变量(tf.global_variables)”和“可训练变量tf.trainable_variables)”的集合中。当您后续创建tf.train.Savertf.train.Optimizer时,这些集合中的变量将用作默认参数。

也就是说,在创建图的过程中,TensorFlow的Python底层会自动用一些collection对op进行归类,方便之后的调用。这部分collection的名字被称为tf.GraphKeys,可以用来获取不同类型的op。当然,我们也可以自定义collection来收集op。

常见GraphKeys

  • GLOBAL_VARIABLES: 该collection默认加入所有的Variable对象,并且在分布式环境中共享。一般来说,TRAINABLE_VARIABLES包含在MODEL_VARIABLES中,MODEL_VARIABLES包含在GLOBAL_VARIABLES中。
  • LOCAL_VARIABLES: GLOBAL_VARIABLES不同的是,它只包含本机器上的Variable,即不能在分布式环境中共享。
  • MODEL_VARIABLES: 顾名思义,模型中的变量,在构建模型中,所有用于正向传递的Variable都将添加到这里。
  • TRAINALBEL_VARIABLES: 所有用于反向传递的Variable,即可训练(可以被optimizer优化,进行参数更新)的变量。
  • SUMMARIES: 跟Tensorboard相关,这里的Variable都由tf.summary建立并将用于可视化。
  • QUEUE_RUNNERS: the QueueRunner objects that are used to produce input for a computation. 
  • MOVING_AVERAGE_VARIABLES: the subset of Variable objects that will also keep moving averages.
  • REGULARIZATION_LOSSES: regularization losses collected during graph construction.

在TensorFlow中也定义了下面几个GraphKeys,但是它们not automatically populated

  • WEIGHTS
  • BIASES
  • ACTIVATIONS

使用TensorFlow的时候定义了变量之后还需要初始化之后才能使用,不然会报错:Attempting to use uninitialized value,下面介绍TensorFlow中常用的几种初始化操作

1、tf.global_variable_initializer

官方介绍地址:https://tensorflow.google.cn/api_docs/python/tf/initializers/global_variables

用来初始化计算图中的全局的变量,全局变量是指创建的变量在tf.GraphKeys.GLOBAL_VARIABLES中,在使用Variable创建变量时默认是collections默认是tf.GraphKeys.GLOBAL_VARIABLES

if __name__ == "__main__":
    v = tf.Variable(1)
    c = tf.constant(2)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(v))
        print(sess.run(c))
2、tf.initialize_all_variables

2017年3月2日之后使用global_variables_initializer()来代替tf.initialize_all_variables()。

3、tf.initialize_local_variables

初始化计算图中所有的局部变量,局部变量是指创建的变量在tf.GraphKeys.LOCAL_VARIABLES中,在使用saver的时候,局部变量是不存在在模型文件中的

if __name__ == "__main__":
    a = tf.Variable(1,name="a",collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        print(sess.run(a))
注意:在使用局部变量时必须使用tf.local_variables_initializer初始化器,在使用全局变量时必须使用tf.global_variables_initializer初始化器,不然会报tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value
 

猜你喜欢

转载自blog.csdn.net/u010365819/article/details/88125375