PyTorch implementation of one-hot (binary) focal loss

import torch

class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=1, alpha=None, reduce="mean"):
        super().__init__()
        self.gamma = gamma    # focusing parameter: down-weights easy examples
        self.alpha = alpha    # optional class-balancing weight
        self.reduce = reduce  # "mean", "sum", or anything else for no reduction

    def forward(self, input, target):
        # input: raw logits; target: one-hot / binary labels of the same shape
        pre = torch.sigmoid(input)
        loss = -(1 - pre) ** self.gamma * target * torch.log(pre) \
               - pre ** self.gamma * (1 - target) * torch.log(1 - pre)
        if self.alpha is not None:
            loss = loss * self.alpha
        if self.reduce == "mean":
            return torch.mean(loss)
        if self.reduce == "sum":
            return torch.sum(loss)
        return loss
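A quick way to sanity-check the formula: with gamma = 0 the focusing factors (1 - p)^gamma and p^gamma become 1, so the loss should reduce to plain binary cross-entropy, and with gamma > 0 easy examples are down-weighted, so the mean loss should shrink. A self-contained functional sketch (the `focal_loss` name and test values here are illustrative, not from the original post):

```python
import torch

def focal_loss(logits, target, gamma=1.0, alpha=None, reduce="mean"):
    # Elementwise binary focal loss on raw logits.
    p = torch.sigmoid(logits)
    loss = -(1 - p) ** gamma * target * torch.log(p) \
           - p ** gamma * (1 - target) * torch.log(1 - p)
    if alpha is not None:
        loss = loss * alpha
    if reduce == "mean":
        return torch.mean(loss)
    if reduce == "sum":
        return torch.sum(loss)
    return loss

logits = torch.tensor([2.0, -1.0, 0.5])
target = torch.tensor([1.0, 0.0, 1.0])

# gamma=0: the focusing term disappears, recovering plain BCE-with-logits.
bce = torch.nn.functional.binary_cross_entropy_with_logits(logits, target)
print(torch.isclose(focal_loss(logits, target, gamma=0.0), bce))

# gamma>0: every element is scaled by a factor <= 1, so the mean drops below BCE.
print(focal_loss(logits, target, gamma=2.0) < bce)
```

Note that for numerical stability in production you would clamp `p` or work in log-space (as `binary_cross_entropy_with_logits` does), since `torch.log(p)` overflows to -inf when the sigmoid saturates.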



Reposted from blog.csdn.net/qq_34291583/article/details/101049496