tensorflow中变量管理reuse参数的使用

TensorFlow用于变量管理的函数主要有两个: tf. get_variable()tf.variable_scope()
前者用于创建和获取变量的值,后者用于生成上下文管理器,创建命名空间,命名空间可以嵌套。

函数tf.get_variable()既可以创建变量,也可以获取变量。控制创建还是获取的开关来自函数tf.variable.scope()中的参数reuse“True”还是"False",分两种情况进行说明:

1. 设置reuse=False时,函数get_variable()表示创建变量

如下面的例子:

with tf.variable_scope("foo",reuse=False):
    v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))

tf.variable_scope()函数中,设置reuse=False时,在其命名空间"foo"中执行函数get_variable()时,表示创建变量"v",若在该命名空间中已经有了变量"v",则在创建时会报错,如下面的例子

import tensorflow as tf

with tf.variable_scope("foo"):
    v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
    v1=tf.get_variable("v",[1])

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-eaed46cad84f> in <module>()
      3 with tf.variable_scope("foo"):
      4     v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
----> 5     v1=tf.get_variable("v",[1])
      6 
ValueError: Variable foo/v already exists, disallowed. 
Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? 

2. 设置reuse=True时,函数get_variable()表示获取变量

如下面的例子:

import tensorflow as tf

with tf.variable_scope("foo"):
    v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
    
with tf.variable_scope("foo",reuse=True):
    v1=tf.get_variable("v",[1])

print(v1==v) 
结果为:
True

tf.variable_scope()函数中,设置reuse=True时,在其命名空间"foo"中执行函数get_variable()时,表示获取变量"v",若在该命名空间中还没有该变量,则在获取时会报错,如下面的例子

import tensorflow as tf 

with tf.variable_scope("foo",reuse=True):
    v1=tf.get_variable("v",[1])

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-019a05c4b9a4> in <module>()
      2 
      3 with tf.variable_scope("foo",reuse=True):
----> 4     v1=tf.get_variable("v",[1])
      5 

ValueError: Variable foo/v does not exist, or was not created with tf.get_variable(). 
Did you mean to set reuse=tf.AUTO_REUSE in VarScope?

3. 结论

TensorFlow通过tf. get_variable()tf.variable_scope()两个函数,可以创建多个并列的或嵌套的命名空间,用于存储神经网络中的各层的权重、偏置、学习率、滑动平均衰减率、正则化系数等参数值,神经网络不同层的参数可放置在不同的命名空间中。同时,变量重用检错和读取不存在变量检错两种机制保证了数据存放的安全性。

猜你喜欢

转载自blog.csdn.net/johnboat/article/details/84846628