Focal Loss

Focal Loss is a loss function that addresses class imbalance and varying example difficulty in classification problems.

Kaiming He's Focal Loss, in its binary form, is:

FL(ŷ, y) = -α (1-ŷ)^γ · y·log ŷ - (1-α) ŷ^γ · (1-y)·log(1-ŷ)

If we instantiate this with the prediction ŷ = σ(x), we get:

FL = -α σ(-x)^γ · y·log σ(x) - (1-α) σ(x)^γ · (1-y)·log σ(-x)

(using σ(-x) = 1 - σ(x))

Through a series of tuning experiments, they found that α = 0.25, γ = 2 worked best (on their model).
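As a minimal sketch of the binary form above (the function name `binary_focal_loss` is my own; the formula is the one just given, on raw logits x):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def binary_focal_loss(y, x, alpha=0.25, gamma=2.0):
    """Elementwise binary focal loss on logits x with labels y in {0, 1}:
    -alpha*(1-p)^gamma * y*log(p) - (1-alpha)*p^gamma * (1-y)*log(1-p),
    where p = sigmoid(x)."""
    p = sigmoid(x)
    pos = -alpha * (1.0 - p) ** gamma * y * np.log(p)
    neg = -(1.0 - alpha) * p ** gamma * (1.0 - y) * np.log(1.0 - p)
    return pos + neg
```

With γ = 0 and α = 0.5 this reduces to half the ordinary cross entropy; with γ > 0, the (1-p_t)^γ factor shrinks the contribution of well-classified examples, which is the whole point of the loss.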

Multi-class:

The multi-class form of Focal Loss is also easy to derive; it is simply:

FL = -(1 - ŷt)^γ · log ŷt

Here ŷt is the predicted probability of the target class, usually the output after softmax. How can the L∗∗ I devised be generalized to the multi-class case? It is also straightforward:

L∗∗ = -σ(-K·xt) · log ŷt    (K is a large constant)

Here xt is also the prediction for the target class, but it is the value before the softmax (i.e., the logit).
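The multi-class focal loss above can be sketched in NumPy as follows (a sketch, not a reference implementation; it assumes integer class indices as targets and `multiclass_focal_loss` is a name I've introduced):

```python
import numpy as np

def softmax(z):
    # subtract the row-wise max for numerical stability
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multiclass_focal_loss(logits, targets, gamma=2.0):
    """FL = -(1 - p_t)^gamma * log(p_t), where p_t is the softmax
    probability assigned to the target class of each example."""
    probs = softmax(logits)
    p_t = probs[np.arange(len(targets)), targets]
    return -((1.0 - p_t) ** gamma) * np.log(p_t)
```

Setting γ = 0 recovers the ordinary softmax cross entropy, so the modulating factor (1 - ŷt)^γ is the only difference.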

A TensorFlow implementation for the multi-class, multi-label case is as follows:

import tensorflow as tf

def focal_loss(labels, logits, gamma=2.0, alpha=0.25, normalize=True):
    # binarize the labels to {0, 1}
    labels = tf.where(labels > 0, tf.ones_like(labels), tf.zeros_like(labels))
    labels = tf.cast(labels, tf.float32)
    probs = tf.sigmoid(logits)
    ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

    # tf.where(cond, a, b): elementwise, picks the element of a where cond
    # is True and the element of b where it is False (all shapes must match)
    alpha_t = tf.ones_like(logits) * alpha
    alpha_t = tf.where(labels > 0, alpha_t, 1.0 - alpha_t)
    probs_t = tf.where(labels > 0, probs, 1.0 - probs)
    focal_matrix = alpha_t * tf.pow((1.0 - probs_t), gamma)
    loss = focal_matrix * ce_loss

    loss = tf.reduce_sum(loss)
    if normalize:
        n_pos = tf.reduce_sum(labels)

        def has_pos():
            # average over the number of positive labels
            return loss / tf.cast(n_pos, tf.float32)

        def no_pos():
            return loss

        loss = tf.cond(n_pos > 0, has_pos, no_pos)
    return loss
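To sanity-check the TensorFlow snippet without running TensorFlow, the same computation can be mirrored in NumPy (a sketch; `focal_loss_np` is a name I've introduced, and the cross-entropy line uses the numerically stable formula that `tf.nn.sigmoid_cross_entropy_with_logits` documents):

```python
import numpy as np

def focal_loss_np(labels, logits, gamma=2.0, alpha=0.25, normalize=True):
    """NumPy mirror of the multi-class, multi-label focal loss above."""
    labels = (labels > 0).astype(np.float64)
    probs = 1.0 / (1.0 + np.exp(-logits))
    # stable sigmoid cross entropy: max(x, 0) - x*z + log(1 + exp(-|x|))
    ce = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    alpha_t = np.where(labels > 0, alpha, 1.0 - alpha)
    probs_t = np.where(labels > 0, probs, 1.0 - probs)
    loss = (alpha_t * (1.0 - probs_t) ** gamma * ce).sum()
    if normalize:
        n_pos = labels.sum()
        if n_pos > 0:
            loss = loss / n_pos
    return loss
```

With γ = 0, α = 0.5 and `normalize=False`, this is exactly half the summed sigmoid cross entropy, which is a useful check when porting the loss to another framework.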


Reprinted from blog.csdn.net/qq_27009517/article/details/83411159