focal_loss

@Focalloss自我理解

sigmoid交叉熵损失

tensorflow中的 tf.nn.sigmoid_cross_entropy_with_logits()是最开始的基础二分类交叉熵,其中将是该类别的标签为1,不是该类别的标签为0。

one_hot编码表示三分类中的第三个类别为[0,0,1],sigmoid识别为第三个类别为正样本,其他的类别为负样本。
tensorflow的实现是在原有基础的公式做了化简,防止内存溢出,做了约束。源码有详细的公式推导

For brevity, let x = logits, z = labels. The logistic loss is

 z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

这里主要讲解如何解决正负样本不均衡的sigmoid交叉熵损失

之前看过很多的博客,相比softmax交叉熵中,二分类的交叉熵负样本会远远大于正样本数量。造成分类器对正样本的分类能力较差。则考虑在正样本中加上权重,让损失更偏重与数量少的正样本的分类,tensorflow有一个简易版的接口

tf.nn.weighted_cross_entropy_with_logits() pos_weight指定正样本的权重值。

torch实现的focal_loss

from torch.autograd import Variable
import torch.functional as F
def one_hot(index, classes):
    size = index.size() + (classes,)
    view = index.size() + (1,)
    mask = torch.Tensor(*size).fill_(0)
    index = index.view(*view)
    ones = 1.
    if isinstance(index, Variable):
        ones = Variable(torch.Tensor(index.size()).fill_(1))
        mask = Variable(mask, volatile=index.volatile)
    return mask.scatter_(1, index, ones)
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
    def forward(self, input, target):
        y = one_hot(target, input.size(-1))
        logit = F.softmax(input)
        logit = logit.clamp(self.eps, 1. - self.eps)
        loss = -1 * y * torch.log(logit) # cross entropy
        loss = loss * (1 - logit) ** self.gamma # focal loss
        return loss.sum()

猜你喜欢

转载自blog.csdn.net/weixin_42662358/article/details/86636277