Focal Loss Implementation

Reference: https://www.jb51.net/article/177667.htm

I have recently been building a GRU model to predict the class of future data. The data is highly imbalanced: class 1 has very few samples, yet it is exactly the class we care about. So I adopted Focal Loss to mitigate the poor results caused by the class imbalance.
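For context, the focal loss from Lin et al. (2017) is FL(p_t) = -α_t · (1 − p_t)^γ · log(p_t), where p_t is the model's predicted probability for the true class; the (1 − p_t)^γ factor down-weights easy, well-classified samples. The code below uses γ = 2 and, in its final form, omits the α_t class-weight term.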

import math

import numpy as np
import torch
import torch.nn.functional as F

def compute_class_weights(histogram):
    # Inverse-log frequency weighting: rare classes get larger weights.
    classWeights = np.ones(5, dtype=np.float32)
    normHist = histogram / np.sum(histogram)
    for i in range(5):
        classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
    return classWeights
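A small sketch of how the weighting behaves (the histogram values are made up for illustration): the dominant class gets a weight near 1 / log(2.0) ≈ 1.4, while a rare class approaches 1 / log(1.10) ≈ 10.5.

hist = np.array([900, 10, 30, 40, 20], dtype=np.float32)
print(compute_class_weights(hist))  # roughly [1.4, 9.6, 8.2, 7.6, 8.8]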

def focal_loss(input, target):
    '''
    Focal loss for a 5-class task, adapted from the scheme shared on
    Zhihu: https://zhuanlan.zhihu.com/p/28527749
    :param input: logits, shape [num_samples, num_classes]
    :param target: class indices, shape [num_samples]
    :return: scalar loss
    '''
    # The original version reshaped 4-D image inputs [n, c, h, w] into
    # [n*h*w, c]; my input is already 2-D, so that step is dropped:
    # n, c, h, w = input.size()
    # target = target.long()
    # inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    # target = target.contiguous().view(-1)
    N = input.size(0)
    C = input.size(1)

    # Per-class sample counts in this batch.
    number_0 = torch.sum(target == 0).item()
    number_1 = torch.sum(target == 1).item()
    number_2 = torch.sum(target == 2).item()
    number_3 = torch.sum(target == 3).item()
    number_4 = torch.sum(target == 4).item()

    frequency = np.array((number_0, number_1, number_2, number_3, number_4), dtype=np.float32)
    classWeights = compute_class_weights(frequency)
    weights = torch.from_numpy(classWeights).float().to(input.device)
    weights = weights[target.view(-1)]  # per-sample weight; this line is very important

    gamma = 2

    P = F.softmax(input, dim=1)  # shape [num_samples, num_classes]
    # One-hot mask of the target classes. (torch.autograd.Variable is
    # deprecated, so a plain tensor replaces the original Variable wrapper.)
    class_mask = torch.zeros(N, C, device=input.device)
    ids = target.view(-1, 1)
    class_mask.scatter_(1, ids, 1.)  # shape [num_samples, num_classes], one-hot

    probs = (P * class_mask).sum(1).view(-1, 1)  # shape [num_samples, 1]
    # Clamp exact zeros away so log() cannot return -inf.
    probs = torch.clamp(probs, min=math.exp(-32), max=1.01)
    log_p = probs.log()

    # Uncomment to apply the class weights as well (note the reshape so the
    # [num_samples] weights broadcast against the [num_samples, 1] probs):
    # batch_loss = -weights.view(-1, 1) * (torch.pow((1 - probs), gamma)) * log_p
    batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

    loss = batch_loss.mean()
    return loss
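A quick smoke test (shapes and values are made up for illustration):

torch.manual_seed(0)
logits = torch.randn(8, 5)           # [num_samples, num_classes] dummy logits
labels = torch.randint(0, 5, (8,))   # class indices in [0, 4]
print(focal_loss(logits, labels))    # scalar tensor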

Compared with the linked version, I made the following changes:

1. The linked code takes images as input, while my input data is 2-D, so the input size only has N and C.

2. Some of my computed probs values were 0, which produced -inf after taking the log, so I clamp probs to bound the zeros away from 0:

probs = torch.clamp(probs, min=math.exp(-32), max=1.01)
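To see why this matters, a quick check (values chosen only for illustration): without the clamp, any zero probability turns into -inf under log(), which then poisons the mean.

print(torch.log(torch.tensor([0.0])))  # tensor([-inf])
print(torch.clamp(torch.tensor([0.0]), min=math.exp(-32)).log())  # tensor([-32.])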

I will add detailed comments to this Focal Loss code later.

Reposted from blog.csdn.net/weixin_39915444/article/details/127277414