[Tensorflow]L2正则化和collection【tf.GraphKeys】

L2-Regularization 实现的话,需要把所有的参数放在一个集合内,最后计算loss时,再减去加权值。

相比自己乱搞,代码一团糟,Tensorflow 提供了更优美的实现方法。

一、tf.GraphKeys : 多个包含Variables(Tensor)集合

 (1)GLOBAL_VARIABLES:使用tf.get_variable()时,默认会将vairable放入这个集合。

   我们熟悉的tf.global_variables_initializer()就是初始化这个集合内的Variables。

import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer())
b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer())
#collections=None等价于 collection=[tf.GraphKeys.GLOBAL_VARIABLES]

gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)          #tf.get_collection(collection_name)返回某个collection的列表
for var in gv: 
  print(var is a)
  print(var.get_shape())

        

   Tips: tf.GraphKeys.GLOBAL_VARIABLES == "variable"。即其保存的是一个字符串。

(2)自定义集合

   想个集合的名字,然后在tf.get_variable时,把集合名字传给 collection 就好了。

import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",shape=[10],collections=["mycollection"])  #不把GLOBAL_VARIABLES加进去,那么就不在那个集合里了。
keys=tf.get_collection("mycollection")
for key in keys:
  print(key.name)

二、L2正则化

先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:

import tensorflow as tf
sess=tf.Session()
weight_decay=0.1
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
"""
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) 
"""
#**上面代码的等价代码
a=tf.get_variable("I_am_a",initializer=tmp)
a2=tf.reduce_sum(a*a)*weight_decay/2;
a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)
#**
sess.run(tf.global_variables_initializer())
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
  print("%s : %s" %(key.name,sess.run(key)))

我们很容易可以模拟出tf.contrib.layers.l2_regularizer都做了什么,不过会让代码变丑。
以下比较完整实现L2 正则化。


import tensorflow as tf
sess=tf.Session()
weight_decay=0.1                                                #(1)定义weight_decay
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)           #(2)定义l2_regularizer()
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)  #(3)创建variable,l2_regularizer复制给regularizer参数。
                                                                #目测REXXX_LOSSES集合
#regularizer定义会将a加入REGULARIZATION_LOSSES集合
print("Global Set:")
keys = tf.get_collection("variables")
for key in keys:
  print(key.name)
print("Regular Set:")
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
  print(key.name)
print("--------------------")
sess.run(tf.global_variables_initializer())
print(sess.run(a))
reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)   #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加
l2_loss=tf.add_n(reg_set)
print("loss=%s" %(sess.run(l2_loss)))
"""
此处输出0.7,即:
   weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
其实代码自己写也很方便,用API看着比较正规。
在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay)
"""


猜你喜欢

转载自blog.csdn.net/vcvycy/article/details/78597350