本章介绍TensorFlow中非常常用的共享变量的使用。其他部分前往:TensorFlow 学习目录
目录
二、摒弃 tf.Variable() 使用 tf.variable_scope() 与 tf.get_variable() 的组合
三、with tf.variable_scope as vs: 对 tf.variable_scope() 的影响
四、使用tf.name_scope,这个命名空间不作用在变量上,而是只作用在OP上面
一、对 tf.Variable() 的讨论
首先我们知道 tf.Variable() 函数的使用方法
import tensorflow as tf
var1 = tf.Variable(tf.constant(0.5), name='var', dtype=tf.float32)
print (var1.name)
输出
var:0
现在思考,如果我想继续再其他的模型中使用这个变量怎么办,比如,GAN网络中的生成器和判别器,如果要是使用 tf.Variable()然后用同样的变量名字,那样会得到一个新的变量而不是我们原先需要。
import tensorflow as tf
var1 = tf.Variable(tf.constant(0.5), name='var', dtype=tf.float32)
print (var1.name)
var2 = tf.Variable(tf.constant(0.5), name='var', dtype=tf.float32)
print (var2.name)
输出
var:0
var_1:0
可以看到,系统直接给了一个新的变量名字,而不是使用之前我们定义的那个 var1。
二、摒弃 tf.Variable() 使用 tf.variable_scope() 与 tf.get_variable() 的组合
- tf.get_variable(),如果变量的名字没有被使用过,那么该语句就是建立一个新的变量,和 tf.Variable() 没有任何的区别,如果变量的名字之前在该“图”中被使用过,那么如果直接使用这个语句会报错,因为一个“图中” tf.get_variable()只能定义同一个名字的变量(样例:code_1),所以如果此时需要共享之前的那个变量,需要配合 tf.variable_scope()(样例:code_2)。
- tf.variable_scope() 相当于一个命名空间,然后可以嵌套使用,可以当作是一个地址,路径之类的东西。
# code_1:
import tensorflow as tf
var3 = tf.get_variable(name='var_', dtype=tf.float32, initializer=tf.constant(3.3333))
print (var3.name)
var4 = tf.get_variable(name='var_', dtype=tf.float32)
print (var4.name)
输出,报错
File "D:/pycodeLIB/TensorFlow/test.py", line 22, in <module>
var4 = tf.get_variable(name='var_', dtype=tf.float32)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1328, in get_variable
constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1090, in get_variable
constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 435, in get_variable
constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 404, in _true_getter
use_resource=use_resource, constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 743, in _get_single_variable
name, "".join(traceback.format_list(tb))))
ValueError: Variable var_ already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 1740, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 3414, in create_op
op_def=op_def)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
下面的代码是正确的使用方法,而且代码中我特意使用了嵌套 tf.variable_scope()的形式。
# code_2
import tensorflow as tf
with tf.variable_scope('test1'):
with tf.variable_scope('test2'):
var3 = tf.get_variable('var1', initializer=tf.constant(value=[1, 2, 3, 4, 5], shape=[5], dtype=tf.float32), dtype=tf.float32)
print (var3.name)
with tf.variable_scope('test1', reuse=True):
with tf.variable_scope('test2'):
var4 = tf.get_variable('var1', dtype=tf.float32)
print (var4.name)
输出结果,可以从结果中看出,此时使用的是用一个参数
test1/test2/var1:0
test1/test2/var1:0
三、with tf.variable_scope as vs: 对 tf.variable_scope() 的影响
如果对一个tf.variable_scope()的嵌套结构的内层 variable_scope初始化为vs那么此时,那么其将不受外部 variable_scope() 的影响,通过比较下面两个代码,来感受一下(该代码和code_2区别在第4行和倒数第3行)
import tensorflow as tf
with tf.variable_scope('test1'):
with tf.variable_scope('test2') as vs:
var3 = tf.get_variable('var1', initializer=tf.constant(value=[1, 2, 3, 4, 5], shape=[5], dtype=tf.float32), dtype=tf.float32)
print (var3.name)
with tf.variable_scope('test1', reuse=True):
with tf.variable_scope(vs):
var4 = tf.get_variable('var1', dtype=tf.float32)
print (var4.name)
输出报错:
Traceback (most recent call last):
File "D:/pycodeLIB/TensorFlow/test.py", line 16, in <module>
var4 = tf.get_variable('var1', dtype=tf.float32)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1328, in get_variable
constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1090, in get_variable
constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 435, in get_variable
constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 404, in _true_getter
use_resource=use_resource, constraint=constraint)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 743, in _get_single_variable
name, "".join(traceback.format_list(tb))))
ValueError: Variable test1/test2/var1 already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 1740, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\ops.py", line 3414, in create_op
op_def=op_def)
File "D:\python3.6.4\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
该错误很明显,是由上面提到的问题的真是反应,就是说,vs不受上一层variable_scope()的影响,所以vs初始化的variable_scope()不可以去共享变量,因为其reuse参数没有设置为True。下面更改之后的代码为
import tensorflow as tf
with tf.variable_scope('test1'):
with tf.variable_scope('test2') as vs:
var3 = tf.get_variable('var1', initializer=tf.constant(value=[1, 2, 3, 4, 5], shape=[5], dtype=tf.float32), dtype=tf.float32)
print (var3.name)
with tf.variable_scope('test1', reuse=True):
with tf.variable_scope(vs, reuse=True):
var4 = tf.get_variable('var1', dtype=tf.float32)
print (var4.name)
输出正确
test1/test2/var1:0
test1/test2/var1:0
四、使用tf.name_scope,这个命名空间不作用在变量上,而是只作用在OP上面
import tensorflow as tf
with tf.variable_scope('test3'):
with tf.name_scope('op_test'):
v = tf.get_variable('v', dtype=tf.float32, initializer=tf.random_normal(shape=[6], mean=0.0, stddev=1.0))
xx = 1.0 + v
print (v.name)
print (xx.op.name)
输出
test3/v:0
test3/op_test/add
可以看出其中的变量没有收到‘op_test'空间的影响,但是OP操作add却受到了'op_test'命名空间的影响。