定义一个类,并实现 __call__ 方法,使它的实例可以像函数一样被调用:
class weighted_cross_entropy(object):
    """Class-weighted categorical cross-entropy, usable as a tflearn loss.

    tflearn's ``regression`` layer invokes a callable ``loss`` as
    ``loss(y_pred, y_true)``, so an instance of this class can be passed
    directly: ``regression(..., loss=weighted_cross_entropy())``.

    Args:
        weight: sequence of per-class weights, one entry per channel of
            ``y_pred``. Defaults to ``DEFAULT_WEIGHT`` (4 classes).
    """

    # Per-class weights used when none are given. How they were derived is
    # not shown here -- presumably from the training-set class balance.
    DEFAULT_WEIGHT = (0.21008659, 0.26289699, 0.28279202, 0.24422441)

    def __init__(self, weight=None):
        self.weight = list(self.DEFAULT_WEIGHT if weight is None else weight)

    def __call__(self, y_pred, y_true):
        """Compute the mean per-pixel weighted cross-entropy.

        Args:
            y_pred: float Tensor of unnormalized scores (logits) with the
                class axis last, e.g. [batch_size, image_width, image_height,
                num_classes] from the U-Net ``conv10`` layer.
            y_true: float Tensor of one-hot labels with the same shape as
                ``y_pred``. tflearn feeds targets as float, so a label map of
                class indices must be one-hot encoded first, e.g.
                ``tf.one_hot(tf.cast(labels, tf.uint8), num_classes)``.

        Returns:
            Scalar Tensor: mean over all pixels of
            ``-sum_c weight[c] * y_true[..., c] * log_softmax(y_pred)[..., c]``.
        """
        # log_softmax stays finite. tf.log(tf.nn.softmax(...)) returns -inf
        # once a probability underflows to 0, and 0 * -inf then makes the
        # whole loss NaN.
        log_prob = tf.nn.log_softmax(y_pred)
        weight = tf.constant(self.weight, dtype=y_pred.dtype)
        # Sum over the class (last) axis to get each pixel's loss, then
        # average over pixels. This matches tflearn's categorical_crossentropy
        # instead of also averaging over classes, which would shrink the loss
        # by a factor of num_classes.
        pixel_loss = -tf.reduce_sum(y_true * log_prob * weight, axis=-1)
        return tf.reduce_mean(pixel_loss)
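方法体的写法可以参照 tflearn/objectives.py 中的实现(例如下面的 categorical_crossentropy)。注意:tflearn 调用损失函数时参数顺序为 (y_pred, y_true),并且传入的 y_pred 和 y_true 都是 float 类型;如果像上例注释中那样使用 tf.one_hot,就需要先把 y_true 强制转换为整数类型(如 tf.uint8)。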
def categorical_crossentropy(y_pred, y_true):
    """ Categorical Crossentropy.

    Computes cross entropy between y_pred (predicted class probabilities,
    e.g. the output of a softmax layer) and y_true (labels).

    Measures the probability error in discrete classification tasks in which
    the classes are mutually exclusive (each entry is in exactly one class).
    For example, each CIFAR-10 image is labeled with one and only one label:
    an image can be a dog or a truck, but not both.

    `y_pred` must NOT be raw logits. It is renormalized by dividing by its
    sum over the last axis and then passed to `log`, which is only meaningful
    for non-negative values.

    `y_pred` and `y_true` must have the same shape, with the class axis last
    (e.g. `[batch_size, num_classes]`), and the same dtype (either `float32`
    or `float64`). It is also required that `y_true` (labels) are binary
    arrays (for example, class 2 out of a total of 5 different classes is
    defined as [0., 1., 0., 0., 0.]).

    Arguments:
        y_pred: `Tensor`. Predicted values.
        y_true: `Tensor` . Targets (labels), a probability distribution.

    Returns:
        Scalar `Tensor`: the cross-entropy averaged over all non-class axes.
    """
    with tf.name_scope("Crossentropy"):
        # Renormalize so each prediction sums to 1 along the class axis.
        # axis=-1 also works when the static rank is unknown, whereas
        # len(y_pred.get_shape()) would fail in that case.
        y_pred /= tf.reduce_sum(y_pred, axis=-1, keep_dims=True)
        # Clip away exact 0 and 1 so tf.log never returns -inf.
        # _EPSILON and _FLOATX are module-level constants defined outside
        # this snippet (in tflearn's objectives.py).
        y_pred = tf.clip_by_value(y_pred, tf.cast(_EPSILON, dtype=_FLOATX),
                                  tf.cast(1. - _EPSILON, dtype=_FLOATX))
        # Manual computation of crossentropy: sum over classes per sample,
        # then average.
        cross_entropy = -tf.reduce_sum(y_true * tf.log(y_pred), axis=-1)
        return tf.reduce_mean(cross_entropy)
在 regression 的 loss 参数中传入该类的一个实例(注意要加括号进行实例化):
# Pass an *instance* (note the parentheses), not the class itself:
# regression() calls a callable loss as loss(y_pred, y_true).
network = regression(
    conv10,
    loss=weighted_cross_entropy(),
    optimizer='adam',
    learning_rate=0.0005,
)
相关源码(tflearn/layers/estimator.py 中 regression 函数的片段):
# Building other ops (loss, training ops...)
# `incoming` is the layer output (y_pred) and `placeholder` holds the targets
# (y_true), matching the custom_loss(y_pred, y_true) reminder below.
# A string names a built-in objective: look it up and apply it.
if isinstance(loss, str):
    loss = objectives.get(loss)(incoming, placeholder)
# Check if function
# Any callable is accepted here -- a plain function or an object that defines
# __call__ (e.g. a weighted_cross_entropy instance). It is called
# positionally as loss(y_pred, y_true).
elif hasattr(loss, '__call__'):
    try:
        loss = loss(incoming, placeholder)
    except Exception as e:
        # Any error while building the loss is printed together with a
        # reminder about the expected signature, then the process is
        # terminated via exit() rather than re-raising.
        print(str(e))
        print('Reminder: Custom loss function arguments must be defined as: '
              'custom_loss(y_pred, y_true).')
        exit()
# Otherwise the loss must already be a Tensor.
elif not isinstance(loss, tf.Tensor):
    raise ValueError("Invalid Loss type.")
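在上面的 elif 分支中:如果 loss 不是字符串,但具有 __call__ 方法(即它是一个可调用对象),regression 就会以 loss(incoming, placeholder) 的形式调用它来计算损失。__call__ 的作用是让类的实例可以像函数一样被调用:obj(a, b) 等价于 obj.__call__(a, b)。因此直接传入 weighted_cross_entropy() 的实例即可。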