【mmsegmentation】Loss模块详解(入门)以调用FocalLoss为例

1、mmdet中损失函数模块简介

1.1. Loss的注册器

先来看段代码:mmseg/models/builder.py


# mmseg/registry/registry.py
# mangage all kinds of modules inheriting `nn.Module`
# MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models'])

from mmseg.registry import MODELS

BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS # 损失
SEGMENTORS = MODELS

这里MODELS注册器同时赋予给了其他模块。
再看看mmseg\models_init_.py

from .assigners import *  # noqa: F401,F403
from .backbones import *  # noqa: F401,F403

from .data_preprocessor import SegDataPreProcessor
from .decode_heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .segmentors import *  # noqa: F401,F403
from .text_encoder import *  # noqa: F401,F403


from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
                      build_head, build_loss, build_segmentor)

__all__ = [
    'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
    'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
]

# build_mtl_SHUAI
1.2. 注册FocalLoss()

models\losses\focal_loss.py
在这里插入图片描述

上述初始化参数比较简单,就两个参数:init():部分主要关注gamma和alpha两个参数,forward()部分主要关注pred和target两个参数。
举个实际例子算一下:

import torch
from mmseg.models import build_loss

# 配置dict
loss_bbox = dict(type='FocalLoss',
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.5,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_focal')

# 从注册器中构建
focal_loss = build_loss(loss_bbox)

# 使用focal loss
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]])   # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
loss = focal_loss(pred, target)
print("loss:",loss)

在这里插入图片描述
在这里插入图片描述

1.3. 总结

基本上mmseg所有损失的计算流程就上述过程,在使用Focal Loss时,不必关心那么多超参,直接build loss然后传入pred和target即可,其余参数基本默认即可。

猜你喜欢

转载自blog.csdn.net/m0_51579041/article/details/142681474