YOLOX改进之损失函数修改(下)

文章内容:如何在YOLOX官方代码中修改定位损失函数

环境:pytorch1.8

修改内容

(1)置信度预测损失更换:二元交叉熵损失替换为FocalLoss或者VariFocalLoss

(2)定位损失更换:IOU损失替换为GIOU、DIOU、CIOU、EIOU以及a-IOU系列

提示:使用之前可以先了解YOLOX及上述损失函数原理

参考链接

YOLOX官网链接:https://github.com/Megvii-BaseDetection/YOLOX

YOLOX原理解析(Bubbliiiing大佬版):https://blog.csdn.net/weixin_44791964/article/details/120476949

FocalLoss损失解析:https://cyqhn.blog.csdn.net/article/details/87343004

VariFocalLoss损失解析:https://blog.csdn.net/weixin_42096202/article/details/108567189

GIOU、CIOU、DIOU、EIOU等:https://blog.csdn.net/neil3611244/article/details/113794197

a-IOU:https://blog.csdn.net/wjytbest/article/details/121513560

使用方法:直接替换即可

代码修改过程

1、IOUloss等其他系列更改

修改位置:只需要在YOLOX-main/yolox/models/losses.py中更改IOUloss的loss_type默认值,如 loss_type="ciou"

class IOUloss(nn.Module):
    """IoU-family localisation loss for boxes encoded as (cx, cy, w, h).

    Supported ``loss_type`` values are ``"iou"``, ``"giou"``, ``"diou"``,
    ``"ciou"`` and ``"eiou"``.  The enclosing module must have ``torch``,
    ``torch.nn as nn`` and ``math`` imported.

    Args:
        reduction: ``"mean"`` or ``"sum"`` to reduce the per-box loss to a
            scalar; any other value (default ``"none"``) returns the
            per-box loss vector unchanged.
        loss_type: which IoU variant to compute.
    """

    # Variants implemented by ``forward``.
    _LOSS_TYPES = ("iou", "giou", "diou", "ciou", "eiou")

    def __init__(self, reduction="none", loss_type="diou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Return the loss between ``pred`` and ``target`` boxes.

        Both tensors are reshaped to ``(-1, 4)`` and read as
        (center x, center y, width, height).

        Raises:
            AssertionError: if the leading dimensions of the inputs differ.
            ValueError: if ``self.loss_type`` is not a supported variant.
        """
        assert pred.shape[0] == target.shape[0]
        if self.loss_type not in self._LOSS_TYPES:
            # Without this check ``loss`` would never be assigned and the
            # reduction step below would fail with an UnboundLocalError.
            raise ValueError(
                "Unsupported loss_type {!r}; expected one of {}".format(
                    self.loss_type, self._LOSS_TYPES
                )
            )

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Corner coordinates of both boxes, shared by every variant.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # ``en`` is 1 where the boxes overlap on both axes, otherwise 0, so
        # non-overlapping pairs get zero intersection area.
        en = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = area_i / (area_u + 1e-7)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        else:
            # Smallest axis-aligned box enclosing both boxes.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)

            if self.loss_type == "giou":
                area_c = torch.prod(c_br - c_tl, 1)
                giou = iou - (area_c - area_u) / area_c.clamp(min=1e-7)
                loss = 1 - giou.clamp(min=-1.0, max=1.0)
            else:
                # Squared diagonal of the enclosing box (epsilon avoids /0).
                convex_dis = (
                    torch.pow(c_br[:, 0] - c_tl[:, 0], 2)
                    + torch.pow(c_br[:, 1] - c_tl[:, 1], 2)
                    + 1e-7
                )
                # Squared distance between the two box centers.
                center_dis = (
                    torch.pow(pred[:, 0] - target[:, 0], 2)
                    + torch.pow(pred[:, 1] - target[:, 1], 2)
                )

                if self.loss_type == "diou":
                    diou = iou - (center_dis / convex_dis)
                    loss = 1 - diou.clamp(min=-1.0, max=1.0)
                elif self.loss_type == "ciou":
                    # Aspect-ratio consistency term.
                    v = (4 / math.pi ** 2) * torch.pow(
                        torch.atan(target[:, 2] / torch.clamp(target[:, 3], min=1e-7))
                        - torch.atan(pred[:, 2] / torch.clamp(pred[:, 3], min=1e-7)),
                        2,
                    )
                    # The trade-off weight is treated as a constant: no
                    # gradient flows through it.
                    with torch.no_grad():
                        alpha = v / ((1 + 1e-7) - iou + v)
                    ciou = iou - (center_dis / convex_dis + alpha * v)
                    loss = 1 - ciou.clamp(min=-1.0, max=1.0)
                else:  # "eiou"
                    # Squared width / height differences of the two boxes.
                    dis_w = torch.pow(pred[:, 2] - target[:, 2], 2)
                    dis_h = torch.pow(pred[:, 3] - target[:, 3], 2)
                    # Squared width / height of the enclosing box.
                    C_w = torch.pow(c_br[:, 0] - c_tl[:, 0], 2) + 1e-7
                    C_h = torch.pow(c_br[:, 1] - c_tl[:, 1], 2) + 1e-7
                    eiou = iou - (center_dis / convex_dis) - (dis_w / C_w) - (dis_h / C_h)
                    loss = 1 - eiou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

2、a-IOU系列修改(其实a-iou与iou就是多了个a次方)

注意:为了不影响原代码,重新在losses.py中续写一个类似IOUloss的alpha_IOUloss

步骤一:在losses.py中续写一个alpha_IOUloss的类(这里只有IOU+GIOU+CIOU,可以参考IOUloss自行补上)

class alpha_IOUloss(nn.Module):
    """Alpha-IoU localisation loss for boxes encoded as (cx, cy, w, h).

    Each term of the chosen IoU variant is raised to the power ``alpha``.
    Supported ``loss_type`` values are ``"iou"``, ``"giou"`` and ``"ciou"``.
    The enclosing module must have ``torch``, ``torch.nn as nn`` and
    ``math`` imported.

    Args:
        reduction: ``"mean"`` or ``"sum"`` to reduce the per-box loss to a
            scalar; any other value (default ``"none"``) returns the
            per-box loss vector unchanged.
        loss_type: which IoU variant to compute.
        alpha: exponent applied to every term of the loss.
    """

    # Variants implemented by ``forward``.
    _LOSS_TYPES = ("iou", "giou", "ciou")

    def __init__(self, reduction="none", loss_type="ciou", alpha=3):
        super(alpha_IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

        self.alpha = alpha

    def forward(self, pred, target):
        """Return the alpha-IoU loss between ``pred`` and ``target`` boxes.

        Both tensors are reshaped to ``(-1, 4)`` and read as
        (center x, center y, width, height).

        Raises:
            AssertionError: if the leading dimensions of the inputs differ.
            ValueError: if ``self.loss_type`` is not a supported variant.
        """
        assert pred.shape[0] == target.shape[0]
        if self.loss_type not in self._LOSS_TYPES:
            # Without this check ``loss`` would never be assigned and the
            # reduction step below would fail with an UnboundLocalError.
            raise ValueError(
                "Unsupported loss_type {!r}; expected one of {}".format(
                    self.loss_type, self._LOSS_TYPES
                )
            )

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Corner coordinates of both boxes, shared by every variant.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # ``en`` is 1 where the boxes overlap on both axes, otherwise 0, so
        # non-overlapping pairs get zero intersection area.
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = area_i / (area_u + 1e-7)

        if self.loss_type == "iou":
            loss = 1 - iou ** self.alpha
        else:
            # Smallest axis-aligned box enclosing both boxes.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)

            if self.loss_type == "giou":
                area_c = torch.prod(c_br - c_tl, 1)
                giou = iou ** self.alpha - (
                    (area_c - area_u) / area_c.clamp(1e-16)
                ) ** self.alpha
                loss = 1 - giou.clamp(min=-1.0, max=1.0)
            else:  # "ciou"
                # Squared diagonal of the enclosing box (epsilon avoids /0).
                convex_dis = (
                    torch.pow(c_br[:, 0] - c_tl[:, 0], 2)
                    + torch.pow(c_br[:, 1] - c_tl[:, 1], 2)
                    + 1e-7
                )
                # Squared distance between the two box centers.
                center_dis = (
                    torch.pow(pred[:, 0] - target[:, 0], 2)
                    + torch.pow(pred[:, 1] - target[:, 1], 2)
                )
                # Aspect-ratio consistency term.
                v = (4 / math.pi ** 2) * torch.pow(
                    torch.atan(target[:, 2] / torch.clamp(target[:, 3], min=1e-7))
                    - torch.atan(pred[:, 2] / torch.clamp(pred[:, 3], min=1e-7)),
                    2,
                )
                # The trade-off weight is treated as a constant: no gradient
                # flows through it.  The epsilon keeps the denominator
                # non-zero when iou == 1 and v == 0 (perfectly matching
                # boxes); without it the result is 0/0 = NaN.
                with torch.no_grad():
                    beta = v / (v - iou + 1 + 1e-7)

                ciou = iou ** self.alpha - (
                    center_dis ** self.alpha / convex_dis ** self.alpha
                    + (beta * v) ** self.alpha
                )
                loss = 1 - ciou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

步骤二:在yolo_head.py中追加调用alpha_IOUloss

from .losses import IOUloss, alpha_IOUloss

步骤三:实例化alpha_IOUloss,并注释之前的IOUloss

# Previous IOUloss instantiation, kept commented out for reference; the
# alpha_IOUloss instance below takes its place under the same attribute name.
#self.iou_loss = IOUloss(reduction="none")
self.iou_loss = alpha_IOUloss(reduction="none")

效果:根据个人数据集而定。不同场景不同效果,对我的数据集基本上都没有原来的好,但也差距不大,降了一点点。

以上代码链接
链接:https://pan.baidu.com/s/1XQAvT2VdtMMLoUo4FY9v8A
提取码:hd9f

猜你喜欢

转载自blog.csdn.net/weixin_45679938/article/details/122347520