Adding an L1Loss to mmsegmentation

Add the L1Loss definition to the mmseg/models/losses/ module:

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES


@LOSSES.register_module()
class L1Loss(nn.Module):
    # TODO: support per-pixel weighting.
    def __init__(self, loss_name='loss_l1', **kwargs):
        super(L1Loss, self).__init__()
        self._loss_name = loss_name

    def forward(self, pred, target, weight=None, ignore_index=None):
        # weight and ignore_index are accepted for API compatibility
        # but are not used yet.
        # pred: (n, c, h, w)   target: (n, h, w)
        classes = pred.shape[1]
        size = list(target.shape)
        size.append(classes)  # (n, h, w, c)
        flat_target = target.view(-1)  # (n*h*w,)
        # One-hot encode by selecting rows of the identity matrix.
        ones = torch.eye(classes, device=flat_target.device)
        ones = ones.index_select(0, flat_target)  # (n*h*w, c)
        ones = ones.view(*size)  # (n, h, w, c)
        target_one_hot = ones.permute(0, 3, 1, 2)  # (n, c, h, w)
        loss = F.l1_loss(pred, target_one_hot)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
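
For the registered class to be usable, it also has to be exported from the losses package and then referenced by type in a config. A minimal sketch, assuming the code above is saved as mmseg/models/losses/l1_loss.py (the filename and the config values are illustrative):

# mmseg/models/losses/__init__.py -- merge into the existing
# imports and __all__ list.
from .l1_loss import L1Loss

__all__ = ['L1Loss']  # in practice, append 'L1Loss' to the existing list

# In a model config, the loss can then be selected by its registered type:
model = dict(
    decode_head=dict(
        loss_decode=dict(type='L1Loss', loss_name='loss_l1')))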

Note that the loss_name property must be present, and the name it returns must be prefixed with loss_, otherwise the loss item will not be included in the backward graph.

The shapes of the incoming pred and target are inconsistent, so target must be converted to match pred before nn.L1Loss() (or F.l1_loss) can be applied directly.
pred.shape: (n, c, h, w)
target.shape: (n, h, w)
So the target needs to be converted to one-hot form. One way to do this is to index_select rows of an identity matrix, which reshapes the labels as sketched below:
(n, h, w) => (n, h, w, c) => (n, c, h, w)
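
A standalone sketch of that conversion (the shapes are chosen arbitrarily for illustration):

import torch

n, c, h, w = 2, 4, 3, 3
target = torch.randint(0, c, (n, h, w))      # integer labels, (n, h, w)

flat = target.view(-1)                       # (n*h*w,)
eye = torch.eye(c)                           # (c, c) identity matrix
one_hot = eye.index_select(0, flat)          # (n*h*w, c): row i is the
                                             # one-hot vector for label i
one_hot = one_hot.view(n, h, w, c)           # (n, h, w, c)
one_hot = one_hot.permute(0, 3, 1, 2)        # (n, c, h, w)

assert one_hot.shape == (n, c, h, w)
assert torch.equal(one_hot.argmax(dim=1), target)  # round-trips the labels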

Further reading: two ways to convert a label map to one-hot encoding in PyTorch.
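
For reference, a sketch of two such alternatives to the index_select approach, using standard PyTorch calls (shapes are illustrative):

import torch
import torch.nn.functional as F

n, c, h, w = 2, 4, 3, 3
target = torch.randint(0, c, (n, h, w))

# 1. F.one_hot appends the class dimension last, so permute afterwards.
one_hot_a = F.one_hot(target, num_classes=c).permute(0, 3, 1, 2).float()

# 2. scatter_ writes ones into a zero tensor at the label indices.
one_hot_b = torch.zeros(n, c, h, w).scatter_(1, target.unsqueeze(1), 1.0)

assert torch.equal(one_hot_a, one_hot_b)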

Original post: blog.csdn.net/qq_39735236/article/details/127806133