Focal Loss 就是一个解决分类问题中类别不平衡、分类难度差异的一个 loss.
Kaiming 大神的 Focal Loss ,二分类形式,是:
如果落实到 ŷ =σ(x) 这个预测,那么就有:
通过一系列调参,得到 α=0.25, γ=2(在他的模型上)的效果最好。
多分类:
Focal Loss 在多分类中的形式也很容易得到,其实就是:
ŷt 是目标的预测值,一般就是经过 softmax 后的结果。那我自己构思的 L∗∗ 怎么推广到多分类?也很简单:
这里 xt 也是目标的预测值,但它是 softmax 前的结果。
tensorlfow实现的multi-class, multi-label 如下:
def focal_loss(self, labels, logits, gamma=2.0, alpha=0.25, normalize=True):
labels = tf.where(labels > 0, tf.ones_like(labels), tf.zeros_like(labels))
labels = tf.cast(labels, tf.float32)
probs = tf.sigmoid(logits)
ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
alpha_t = tf.ones_like(logits) * alpha
alpha_t = tf.where(labels > 0, alpha_t, 1.0 - alpha_t)
probs_t = tf.where(labels > 0, probs, 1.0 - probs)
# tf.where(input, a,b),其中a,b均为尺寸一致的tensor,作用是将a中对应input中true的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值
focal_matrix = alpha_t * tf.pow((1.0 - probs_t), gamma)
loss = focal_matrix * ce_loss
loss = tf.reduce_sum(loss)
if normalize:
n_pos = tf.reduce_sum(labels)
# total_weights = tf.stop_gradient(tf.reduce_sum(focal_matrix))
# total_weights = tf.Print(total_weights, [n_pos, total_weights])
# loss = loss / total_weights
def has_pos():
return loss / tf.cast(n_pos, tf.float32)
def no_pos():
#total_weights = tf.stop_gradient(tf.reduce_sum(focal_matrix))
#return loss / total_weights
return loss
loss = tf.cond(n_pos > 0, has_pos, no_pos)
return loss