常用损失函数(二):Dice Loss

        Dice Loss是由Dice系数而得名的,Dice系数是一种用于评估两个样本相似性的度量函数,其值越大意味着这两个样本越相似,Dice系数的数学表达式如下:

Dice=\frac{2|X\cap Y|}{|X|+|Y|}

其中|X\cap Y|表示X和Y之间交集元素的个数,|X||Y|分别表示X、Y中元素的个数。Dice Loss表达式如下:

Dice Loss=1-Dice=1-\frac{2|X\cap Y|}{|X|+|Y|}

1、①Dice Loss常用于语义分割问题中,X表示真实分割图像的像素标签,Y表示模型预测分割图像的像素类别,|X\cap Y|近似为预测图像的像素与真实标签图像的像素之间的点乘,并将点乘结果相加,|X||Y|分别近似为它们各自对应图像中的像素相加。

②对于二分类问题,真实分割标签图像的像素只有0,1两个值,因此|X\cap Y|可以有效地将在预测分割图像中未在真实分割标签图像中激活的所有像素值清零,对于激活的像素,主要是惩罚低置信度的预测,置信度高的预测会得到较高的Dice系数,从而得到较低的Dice Loss。即:

Dice Loss=1-\frac{2\sum_{i=1}^{N}y_{i}\hat{y_{i}}}{\sum_{i=1}^{N}y_{i}+\sum_{i=1}^{N}\hat{y_{i}}}

其中,y_{i}\hat{y_{i}}分别表示像素i的标签值与预测值,N为像素点总个数,等于单张图像的像素个数乘以batchsize。

2、可以说Dice Loss是直接优化F1 score而来的,是对F1 score的高度抽象,可用于多分类分割问题上。其中查准率(精确率)公式如下,表示在预测为1的样本中实际为1的概率:

P=\frac{TP}{TP+FP}

查全率(召回率)的公式如下,表示在实际为1的样本中预测为1的概率:

R=\frac{TP}{TP+FN}

查准率和查全率往往是相互制约的,如果提高模型的查准率,就会降低模型的查全率;提高模型的查全率就会降低模型的查准率。为了平衡这两者的关系,F1 score就被提出,其公式如下:

F1 score=\frac{2PR}{P+R}=\frac{2TP}{2TP+FP+FN}

在二分类问题中,Dice系数也可以写成Dice=\frac{2TP}{2TP+FP+FN}=F1 score

3、Dice Loss可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标。Dice Loss训练更关注对前景区域的挖掘,即保证有较低的FN,但会存在损失饱和问题,而CE Loss是平等地计算每个像素点的损失,当前点的损失只和当前预测值与真实标签值的距离有关,这会导致一些问题(见Focal Loss)。因此单独使用Dice Loss往往并不能取得较好的结果,需要进行组合使用,比如Dice Loss+CE Loss或者Dice Loss+Focal Loss等。

4、Dice Loss的代码实现如下:

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss

猜你喜欢

转载自blog.csdn.net/Mike_honor/article/details/125871091