@Focalloss自我理解
sigmoid交叉熵损失
tensorflow中的 tf.nn.sigmoid_cross_entropy_with_logits()是最开始的基础二分类交叉熵,其中将是该类别的标签为1,不是该类别的标签为0。
one_hot编码表示三分类中的第三个类别为[0,0,1],sigmoid识别为第三个类别为正样本,其他的类别为负样本。
tensorflow的实现是在原有基础的公式做了化简,防止内存溢出,做了约束。源码有详细的公式推导
For brevity, let x = logits
, z = labels
. The logistic loss is
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))
这里主要讲解如何解决正负样本不均衡的sigmoid交叉熵损失
之前看过很多的博客,相比softmax交叉熵中,二分类的交叉熵负样本会远远大于正样本数量。造成分类器对正样本的分类能力较差。则考虑在正样本中加上权重,让损失更偏重与数量少的正样本的分类,tensorflow有一个简易版的接口
tf.nn.weighted_cross_entropy_with_logits() pos_weight指定正样本的权重值。
torch实现的focal_loss
from torch.autograd import Variable
import torch.functional as F
def one_hot(index, classes):
size = index.size() + (classes,)
view = index.size() + (1,)
mask = torch.Tensor(*size).fill_(0)
index = index.view(*view)
ones = 1.
if isinstance(index, Variable):
ones = Variable(torch.Tensor(index.size()).fill_(1))
mask = Variable(mask, volatile=index.volatile)
return mask.scatter_(1, index, ones)
class FocalLoss(nn.Module):
def __init__(self, gamma=0, eps=1e-7):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.eps = eps
def forward(self, input, target):
y = one_hot(target, input.size(-1))
logit = F.softmax(input)
logit = logit.clamp(self.eps, 1. - self.eps)
loss = -1 * y * torch.log(logit) # cross entropy
loss = loss * (1 - logit) ** self.gamma # focal loss
return loss.sum()