Reference: https://www.jb51.net/article/177667.htm
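I have recently been building a GRU model to predict the class of future data. The data is highly imbalanced: class 1, the class we care most about, has very few samples. I therefore decided to try Focal Loss to mitigate the poor performance caused by the imbalance.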
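import math
import numpy as np
import torch
import torch.nn.functional as F

def compute_class_weights(histogram):
    classWeights = np.ones(5, dtype=np.float32)
    normHist = histogram / np.sum(histogram)
    for i in range(5):
        # rarer classes get larger weights; the 1.10 offset keeps the log positive
        classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
    return classWeights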
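As a quick check of how this weighting behaves, the sketch below applies the same formula to a made-up, heavily imbalanced histogram (the counts are hypothetical, not from my dataset). Rare classes should receive larger weights, and because 1.10 + normHist[i] always lies between 1.10 and 2.10, the logarithm stays positive and the weights stay finite.

```python
import numpy as np

def compute_class_weights(histogram):
    # Same formula as above: weight = 1 / log(1.10 + class frequency)
    norm_hist = histogram / np.sum(histogram)
    return (1.0 / np.log(1.10 + norm_hist)).astype(np.float32)

# Hypothetical class counts for 5 classes; class 1 is the rare one
histogram = np.array([700, 20, 150, 80, 50], dtype=np.float32)
class_weights = compute_class_weights(histogram)
print(class_weights)
```

With these counts the rare class 1 gets the largest weight and the dominant class 0 the smallest.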
def focal_loss(input, target):
    '''
    Focal loss, based on the approach posted on Zhihu: https://zhuanlan.zhihu.com/p/28527749
    :param input: raw logits, shape [num_samples, num_classes]
    :param target: integer class labels, shape [num_samples]
    :return: scalar loss averaged over the batch
    '''
    # The linked version takes image input of shape [n, c, h, w] and flattens it first:
    # n, c, h, w = input.size()
    #
    # target = target.long()
    # inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    # target = target.contiguous().view(-1)
    N = input.size(0)
    C = input.size(1)
    # count how many samples of each class appear in this batch
    number_0 = torch.sum(target == 0).item()
    number_1 = torch.sum(target == 1).item()
    number_2 = torch.sum(target == 2).item()
    number_3 = torch.sum(target == 3).item()
    number_4 = torch.sum(target == 4).item()
    frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4), dtype=torch.float32)
    frequency = frequency.numpy()
    classWeights = compute_class_weights(frequency)
    weights = torch.from_numpy(classWeights).float().to(input.device)
    # This line is very important: it looks up the class weight of every sample, shape [num_samples].
    # The weights are only used by the disabled weighted loss further down.
    weights = weights[target.view(-1)]
    gamma = 2
    P = F.softmax(input, dim=1)  # shape [num_samples, num_classes]
    class_mask = torch.zeros_like(input)  # replaces the deprecated Variable wrapper
    ids = target.view(-1, 1)
    class_mask.scatter_(1, ids, 1.)  # shape [num_samples, num_classes], one-hot encoding
    probs = (P * class_mask).sum(1).view(-1, 1)  # shape [num_samples, 1]
    # floor the probabilities at e^-32 so that log() never sees an exact 0
    probs = torch.clamp(probs, min=math.exp(-32), max=1.01)
    log_p = probs.log()
    # print('in calculating batch_loss', weights.shape, probs.shape, log_p.shape)
    # Weighted variant (disabled). weights has shape [num_samples] while probs has shape
    # [num_samples, 1], so weights would need .view(-1, 1) to avoid broadcasting to [N, N].
    # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
    batch_loss = -(torch.pow((1 - probs), gamma)) * log_p
    loss = batch_loss.mean()
    # print('loss: ', loss)
    return loss
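For comparison, here is a more compact sketch of the same unweighted loss. It is not the code I ran above: it assumes that picking the target-class log-probability with `gather` on `F.log_softmax` is an acceptable replacement for the one-hot mask, which also avoids taking the log of an exact 0. The name `focal_loss_compact` and the toy logits are made up for illustration.

```python
import torch
import torch.nn.functional as F

def focal_loss_compact(logits, target, gamma=2.0):
    # logits: [num_samples, num_classes]; target: [num_samples] integer labels
    log_probs = F.log_softmax(logits, dim=1)
    # Pick the log-probability of the true class for each sample -> [num_samples]
    log_p = log_probs.gather(1, target.view(-1, 1)).squeeze(1)
    p = log_p.exp()
    # (1 - p)^gamma shrinks the loss of samples that are already classified well
    return (-(1 - p) ** gamma * log_p).mean()

# Toy batch: sample 0 is confidently correct, sample 1 is confidently wrong
logits = torch.tensor([[4.0, 0.0, 0.0, 0.0, 0.0],
                       [4.0, 0.0, 0.0, 0.0, 0.0]])
target = torch.tensor([0, 1])
loss = focal_loss_compact(logits, target)
ce = F.cross_entropy(logits, target)
print(loss.item(), ce.item())
```

With gamma set to 0 the modulating factor becomes 1 and the sketch reduces to ordinary cross-entropy, which is a convenient sanity check.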
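Compared with the linked code, I made the following changes:
1. The linked code takes images as input, whereas my input data is two-dimensional, so input.size() only has N and C.
2. Some of the probs I computed were 0, which produced -inf after taking the log, so I clamp probs to bound the 0 values away from 0.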
probs = torch.clamp(probs, min=math.exp(-32), max=1.01)
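To make the second change concrete, the sketch below uses a hand-made probs tensor (hypothetical values) that contains an exact 0. Without the clamp the log is -inf; with the clamp the smallest possible log value is about -32, since log(e^-32) = -32.

```python
import math
import torch

# Hypothetical target-class probabilities; the first one is exactly 0
probs = torch.tensor([[0.0], [0.25], [0.9]])

raw_log = probs.log()  # first entry is -inf
clamped_log = torch.clamp(probs, min=math.exp(-32), max=1.01).log()

print(raw_log.view(-1))
print(clamped_log.view(-1))
```

Note that a clamped sample is not ignored: it still contributes a large but finite term to the loss.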
I will add detailed comments to this Focal Loss code in a later update.