笔记 - 模型训练:正则化 Loss

前置知识

  • tf.add_to_collection / tf.get_collection:collection 是按名字存放一组值的列表,add 追加元素,get 取回整个列表
import tensorflow as tf

# Demo of graph collections: tf.add_to_collection appends a value to the list
# stored under a string key, and tf.get_collection returns that list.
tf.add_to_collection("reg_losses", 1.0)
tf.add_to_collection("reg_losses", 1.0)
loss = tf.get_collection("reg_losses")
# The stored values are plain Python floats and get_collection returns a
# Python list, so no tf.Session is needed just to inspect it.
print(loss)

"""
运行结果:
[1.0, 1.0]
"""

添加正则化 Loss

  • 手动添加:自己计算正则项,再用 tf.add_to_collection 加入自定义 collection
# Manual registration: compute the penalty tensor yourself and append it to a
# custom collection.
# NOTE(review): `w` is assumed to be a weight variable/tensor defined earlier
# in the model -- confirm against the surrounding code.
# This is an L2 penalty (sum of squared weights). For an L1 penalty use
# tf.reduce_sum(tf.abs(w)) instead -- pick one; the original assigned both
# and the L1 value was immediately overwritten.
# No strength coefficient is applied here; multiply by a scale factor if the
# penalty should be weighted against the data loss.
reg_loss = tf.reduce_sum(tf.square(w))
tf.add_to_collection("reg_losses", reg_loss)
  • 自动添加:为层设置 weights_regularizer,正则项会自动加入 tf.GraphKeys.REGULARIZATION_LOSSES
# Automatic registration: inside this arg_scope every fully_connected call
# defaults to weights_regularizer=l2_regularizer(scale=0.01), so the layer's
# weight penalty is registered by TensorFlow instead of by hand (it is read
# back from tf.GraphKeys.REGULARIZATION_LOSSES when merging).
# The same regularizer can instead be passed to a single layer directly:
#     fully_connected(..., weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.01))
# NOTE(review): fully_connected, X and n_hidden1 are assumed to be defined
# earlier (presumably tf.contrib.layers.fully_connected) -- confirm.
with tf.contrib.framework.arg_scope(
        [fully_connected],
        weights_regularizer=tf.contrib.layers.l2_regularizer(scale=0.01)):
    hidden1 = fully_connected(X, n_hidden1, scope="hidden1")

合并正则化 Loss(读取的 collection 要与添加方式一致)

 # Merge: sum the registered regularization terms and add them to the base loss.
# Read the collection that matches how the terms were registered:
#   - manual tf.add_to_collection("reg_losses", ...)  -> tf.get_collection("reg_losses")
#   - weights_regularizer=... on layers (arg_scope)   -> tf.GraphKeys.REGULARIZATION_LOSSES
# Terms added manually to "reg_losses" are NOT included by the line below.
# tf.add_n requires at least one input, so fall back to 0.0 when no term was
# registered (e.g. no layer had a regularizer).
# NOTE(review): `loss` is assumed to be the model's float32 data-loss tensor
# defined earlier -- confirm dtype matches the 0.0 fallback.
reg_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
reg_losses = tf.add_n(reg_terms) if reg_terms else tf.constant(0.0)
total_loss = tf.add(loss, reg_losses)

猜你喜欢

转载自blog.csdn.net/chen_holy/article/details/91437016