[BlendMask] BlendMask: Top-Down Meets Bottom-Up for Instance Segmentation (code notes)

BlendMask: Top-Down Meets Bottom-Up for Instance Segmentation

If this helps you, I'd appreciate a like~


BlendMask extends the anchor-free FCOS object detector to instance segmentation.
The overall execution order is: backbone (ResNet + FPN) --> FCOS --> blendmask.py --> basis_module.py --> blender.py, sketched just below.
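For orientation, here is a minimal sketch of that call order. It is my own simplification (the argument names mirror the BlendMask attributes shown later, but the real AdelaiDet signatures differ):

def blendmask_forward_sketch(images, backbone, basis_module, proposal_generator, blender, top_layer):
    # Hypothetical helper, not part of AdelaiDet; it only illustrates the data flow.
    features = backbone(images)                              # ResNet + FPN -> {'p3', ..., 'p7'}
    basis_out, basis_losses = basis_module(features, None)   # bottom module (ProtoNet), K=4 bases
    proposals, proposal_losses = proposal_generator(         # FCOS; top_layer adds the top_feat branch
        images, features, None, top_layer)
    results, detector_losses = blender(basis_out["bases"], proposals, None)  # blend attentions with bases
    return results, {**basis_losses, **proposal_losses, **detector_losses}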

For the FCOS part of the code, see my separate FCOS code notes post.

BlendMask network structure:

(architecture figures omitted)

The top_feat structure newly added to the FCOS branch for BlendMask:

def losses(self, logits_pred, reg_pred, ctrness_pred, locations, gt_instances, top_feats=None):
        """
        Return the losses from a set of FCOS predictions and their associated ground-truth.

        Returns:
            dict[loss name -> loss value]: A dict mapping from loss name to loss value.
        """
        '''
        (the parts identical to the original FCOS losses are omitted here)
        '''
        if len(top_feats) > 0: # BlendMask branch
            instances.top_feats = cat([
                # Reshape: (N, -1, Hi, Wi) -> (N*Hi*Wi, -1)   [784, -1]
                x.permute(0, 2, 3, 1).reshape(-1, x.size(1)) for x in top_feats
            ], dim=0,)
        '''
        in blendmask:
            top_feats[0].size()
        torch.Size([2, 784, 96, 148])
            top_feats[1].size()
        torch.Size([2, 784, 48, 74])
            top_feats[2].size()
        torch.Size([2, 784, 24, 37])
            top_feats[3].size()
        torch.Size([2, 784, 12, 19])
            top_feats[4].size()
        torch.Size([2, 784, 6, 10])
        '''
        # instances.top_feats.size() -> [37872, 784]; the following fcos_losses(self, instances) call
        # filters this further, so only [num_pos_instances, 784] remains.
        # This is the attention written as a matrix: each row holds 784 features (784 = K*M*M channels),
        # and 37872 = batch_size * sum over FPN levels of Hi*Wi.
        # In plain terms, each 2-D Hi x Wi map is flattened into a 1-D vector of length Hi*Wi
        # (a toy reshape check follows this block).
        pdb.set_trace()
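A toy check of that flattening, in plain PyTorch and independent of the AdelaiDet code (batch N=2 and C = K*M*M = 784 follow the trace above; only the first two levels are used here):

import torch

# Per level: (N, C, Hi, Wi) -> (N*Hi*Wi, C), then concatenate over levels.
top_feats = [torch.randn(2, 784, 96, 148), torch.randn(2, 784, 48, 74)]
flat = torch.cat(
    [x.permute(0, 2, 3, 1).reshape(-1, x.size(1)) for x in top_feats], dim=0)
print(flat.shape)  # torch.Size([35520, 784]) = [2 * (96*148 + 48*74), 784]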

1. AdelaiDet/adet/modeling/blendmask/blendmask.py

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
from torch import nn

from detectron2.structures import ImageList
from detectron2.modeling.postprocessing import detector_postprocess, sem_seg_postprocess
from detectron2.modeling.proposal_generator import build_proposal_generator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.meta_arch.panoptic_fpn import combine_semantic_and_instance_outputs
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.semantic_seg import build_sem_seg_head

from .blender import build_blender
from .basis_module import build_basis_module
import pdb
__all__ = ["BlendMask"]


@META_ARCH_REGISTRY.register()
class BlendMask(nn.Module):
    """
    Main class for BlendMask architectures (see https://arxiv.org/abs/1901.02446).
    """

    def __init__(self, cfg):
        super().__init__()

        self.device = torch.device(cfg.MODEL.DEVICE)
        self.instance_loss_weight = cfg.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT # 1.0
        self.backbone = build_backbone(cfg) # build_fcos_resnet_fpn_backbone
        pdb.set_trace()
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) # FCOS
        pdb.set_trace()
        self.blender = build_blender(cfg) # blender
        pdb.set_trace()
        self.basis_module = build_basis_module(cfg, self.backbone.output_shape()) # basis_module
        pdb.set_trace()

        # options when combining instance & semantic outputs
        self.combine_on = cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED # FALSE
        if self.combine_on: 
            self.panoptic_module = build_sem_seg_head(cfg, self.backbone.output_shape())
            self.combine_overlap_threshold = cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH
            self.combine_stuff_area_limit = cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT
            self.combine_instances_confidence_threshold = (
                cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH)

        # build top module
        in_channels = cfg.MODEL.FPN.OUT_CHANNELS # 256
        num_bases = cfg.MODEL.BASIS_MODULE.NUM_BASES # 4
        attn_size = cfg.MODEL.BLENDMASK.ATTN_SIZE # 14
        attn_len = num_bases * attn_size * attn_size # K*M*M =  784
        self.top_layer = nn.Conv2d(
            in_channels, attn_len,
            kernel_size=3, stride=1, padding=1)
        # self.top_layer Conv2d(256, 784, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        torch.nn.init.normal_(self.top_layer.weight, std=0.01)
        torch.nn.init.constant_(self.top_layer.bias, 0)

        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        self.to(self.device)
        pdb.set_trace()
    def forward(self, batched_inputs): # during BlendMask training, execution enters BlendMask.forward() first
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

        For now, each item in the list is a dict that contains:
            image: Tensor, image in (C, H, W) format.
            instances: Instances
            sem_seg: semantic segmentation ground truth.
            Other information that's included in the original dicts, such as:
                "height", "width" (int): the output resolution of the model, used in inference.
                    See :meth:`postprocess` for details.

        Returns:
            list[dict]: each dict is the results for one image. The dict
                contains the following keys:
                "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                    See the return value of
                    :func:`combine_semantic_and_instance_outputs` for its format.
        """
        images = [x["image"].to(self.device) for x in batched_inputs] # move the images in batched_inputs onto the GPU
        images = [self.normalizer(x) for x in images] # normalize with the pixel mean/std
        images = ImageList.from_tensors(images, self.backbone.size_divisibility) #   self.backbone.size_divisibility 32
        features = self.backbone(images.tensor) # resnet-fpn forward 
        pdb.set_trace()
        if self.combine_on: # False
            if "sem_seg" in batched_inputs[0]:
                gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs]
                gt_sem = ImageList.from_tensors(
                    gt_sem, self.backbone.size_divisibility, self.panoptic_module.ignore_value
                ).tensor
            else:
                gt_sem = None
            sem_seg_results, sem_seg_losses = self.panoptic_module(features, gt_sem)

        if "basis_sem" in batched_inputs[0]: # True  [1273, 768], batched_inputs[0].keys(): dict_keys(['file_name', 'height', 'width', 'image_id', 'image', 'instances', 'basis_sem'])
            basis_sem = [x["basis_sem"].to(self.device) for x in batched_inputs]
            basis_sem = ImageList.from_tensors(
                basis_sem, self.backbone.size_divisibility, 0).tensor
        else:
            basis_sem = None
        basis_out, basis_losses = self.basis_module(features, basis_sem) # the ResNet-FPN features go through basis_module; basis_losses is the auxiliary semantic loss, basis_out holds the features after the refine --> tower branches
        pdb.set_trace()
        if "instances" in batched_inputs[0]: # True
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs] # len(gt_instances) = batch size
        else:
            gt_instances = None
        # see around line 444 of fcos_outputs.py: self.top_layer does not take part in the original FCOS branches or losses, it only adds one extra channel transform 256 --> 784
        proposals, proposal_losses = self.proposal_generator(  # the newly added self.top_layer is passed through to fcos.forward()
            images, features, gt_instances, self.top_layer)
        pdb.set_trace()
        detector_results, detector_losses = self.blender( # invokes Blender.__call__
            basis_out["bases"], proposals, gt_instances)
        pdb.set_trace()

        if self.training:
            losses = {}
            losses.update(basis_losses) # auxiliary semantic loss
            losses.update({k: v * self.instance_loss_weight for k, v in detector_losses.items()})
            losses.update(proposal_losses)
            if self.combine_on: # False
                losses.update(sem_seg_losses)
            return losses
        '''
        (Pdb) losses
        {
        'loss_basis_sem': tensor(1.3058, device='cuda:0', grad_fn=<MulBackward0>), 
        'loss_mask': tensor(0.6931, device='cuda:0', grad_fn=<MulBackward0>), 
        'loss_fcos_cls': tensor(1.1881, device='cuda:0', grad_fn=<DivBackward0>), 
        'loss_fcos_loc': tensor(0.9733, device='cuda:0', grad_fn=<DivBackward0>), 
        'loss_fcos_ctr': tensor(0.7431, device='cuda:0', grad_fn=<DivBackward0>)
        }


        '''
        processed_results = []
        pdb.set_trace()
        for i, (detector_result, input_per_image, image_size) in enumerate(zip(
                detector_results, batched_inputs, images.image_sizes)):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            detector_r = detector_postprocess(detector_result, height, width)
            processed_result = {"instances": detector_r}
            if self.combine_on:
                sem_seg_r = sem_seg_postprocess(
                    sem_seg_results[i], image_size, height, width)
                processed_result["sem_seg"] = sem_seg_r
            if "seg_thing_out" in basis_out:
                seg_thing_r = sem_seg_postprocess(
                    basis_out["seg_thing_out"], image_size, height, width)
                processed_result["sem_thing_seg"] = seg_thing_r
            if self.basis_module.visualize:
                processed_result["bases"] = basis_out["bases"]
            processed_results.append(processed_result)

            if self.combine_on:
                panoptic_r = combine_semantic_and_instance_outputs(
                    detector_r,
                    sem_seg_r.argmax(dim=0),
                    self.combine_overlap_threshold,
                    self.combine_stuff_area_limit,
                    self.combine_instances_confidence_threshold)
                processed_results[-1]["panoptic_seg"] = panoptic_r
            pdb.set_trace()
        pdb.set_trace()
        return processed_results

'''
self.basis_module
    ProtoNet(
        (refine): ModuleList(
            (0): Sequential(
            (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
            (1): Sequential(
            (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
            (2): Sequential(
            (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
        )
        (tower): Sequential(
            (0): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
            (1): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
            (2): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
            (3): Upsample(scale_factor=2.0, mode=bilinear)
            (4): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            )
            (5): Conv2d(128, 4, kernel_size=(1, 1), stride=(1, 1))
        )
        (seg_head): Sequential(
            (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
            (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (5): ReLU()
            (6): Conv2d(128, 81, kernel_size=(1, 1), stride=(1, 1))
        )
    )

'''



'''
(Pdb) batched_inputs[0]
{'file_name': '/hdd2/wh/datasets/coco/train2017/000000522935.jpg', 'height': 480, 'width': 640, 'image_id': 522935, 'image': tensor([[[255, 254, 253,  ..., 251, 253, 254],
         [254, 253, 253,  ..., 253, 254, 254],
         [252, 252, 253,  ..., 255, 255, 255],
         ...,
         [254, 254, 254,  ..., 253, 254, 254],
         [253, 254, 255,  ..., 254, 255, 255],
         [253, 254, 255,  ..., 255, 255, 255]],

        [[255, 254, 253,  ..., 251, 253, 254],
         [254, 253, 253,  ..., 253, 254, 254],
         [252, 252, 253,  ..., 255, 255, 255],
         ...,
         [254, 254, 254,  ..., 253, 254, 254],
         [254, 254, 255,  ..., 254, 254, 254],
         [254, 254, 255,  ..., 255, 254, 254]],

        [[255, 254, 253,  ..., 251, 253, 254],
         [254, 253, 253,  ..., 253, 254, 254],
         [252, 252, 253,  ..., 255, 255, 255],
         ...,
         [254, 254, 254,  ..., 253, 254, 254],
         [253, 253, 255,  ..., 254, 254, 253],
         [252, 253, 255,  ..., 255, 254, 253]]], dtype=torch.uint8), 'instances': Instances(num_instances=7, image_height=768, image_width=1024, fields=[gt_boxes: Boxes(tensor([[ 77.8720, 118.6560, 825.0880, 709.4400],
        [375.2480, 249.3760, 440.5120, 345.9040],
        [358.9760, 115.6320, 724.8480, 667.9040],
        [136.2240, 487.0880, 309.0400, 559.6639],
        [376.3360, 512.4800, 571.8560, 644.8800],
        [265.7760, 136.3360, 828.4000, 711.0400],
        [ 86.0960, 120.4480, 483.8720, 712.8160]])), gt_classes: tensor([57, 67,  0, 73, 73, 56, 56]), gt_masks: PolygonMasks(num_instances=7)]), 'basis_sem': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])}

'''
'''
features 
    (Pdb) features['p3'].size()
    torch.Size([2, 256, 88, 128])
    (Pdb) features['p4'].size()
    torch.Size([2, 256, 44, 64])
    (Pdb) features['p5'].size()
    torch.Size([2, 256, 22, 32])
    (Pdb) features['p6'].size()
    torch.Size([2, 256, 11, 16])
    (Pdb) features['p7'].size()
    torch.Size([2, 256, 6, 8])

'''

'''
(Pdb) basis_losses
{'loss_basis_sem': tensor(1.3058, device='cuda:0', grad_fn=<MulBackward0>)}

proposal_losses
{'loss_fcos_cls': tensor(1.1881, device='cuda:0', grad_fn=<DivBackward0>), 'loss_fcos_loc': tensor(0.9733, device='cuda:0', grad_fn=<DivBackward0>), 'loss_fcos_ctr': tensor(0.7431, device='cuda:0', grad_fn=<DivBackward0>)}

detector_losses
{'loss_mask': tensor(0.6931, device='cuda:0', grad_fn=<DivBackward0>)}

'''
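As a hedged note on how this dict is consumed downstream: to my understanding, Detectron2's default training loop simply sums every entry of the returned loss dict into one scalar before calling backward(). A minimal reproduction with the values from the trace above:

import torch

loss_dict = {
    "loss_basis_sem": torch.tensor(1.3058),  # auxiliary semantic loss (0.3 weight already applied inside basis_module)
    "loss_mask": torch.tensor(0.6931),       # blender loss, scaled by instance_loss_weight (1.0)
    "loss_fcos_cls": torch.tensor(1.1881),
    "loss_fcos_loc": torch.tensor(0.9733),
    "loss_fcos_ctr": torch.tensor(0.7431),
}
total_loss = sum(loss_dict.values())
print(total_loss)  # tensor(4.9034)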

2. AdelaiDet/adet/modeling/blendmask/blender.py

import torch
from torch.nn import functional as F

from detectron2.layers import cat
from detectron2.modeling.poolers import ROIPooler

import pdb
def build_blender(cfg):
    return Blender(cfg)


class Blender(object):
    def __init__(self, cfg):

        # fmt: off
        self.pooler_resolution = cfg.MODEL.BLENDMASK.BOTTOM_RESOLUTION #56
        sampling_ratio         = cfg.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO #1
        pooler_type            = cfg.MODEL.BLENDMASK.POOLER_TYPE  # 'ROIAlignV2'
        pooler_scales          = cfg.MODEL.BLENDMASK.POOLER_SCALES  # (0.25,)
        self.attn_size         = cfg.MODEL.BLENDMASK.ATTN_SIZE # 14
        self.top_interp        = cfg.MODEL.BLENDMASK.TOP_INTERP # 'bilinear'
        num_bases              = cfg.MODEL.BASIS_MODULE.NUM_BASES # 4
        # fmt: on
        
        self.attn_len = num_bases * self.attn_size * self.attn_size # 4 * 14 * 14

        self.pooler = ROIPooler(
            output_size=self.pooler_resolution, # 56
            scales=pooler_scales, # 0.25
            sampling_ratio=sampling_ratio, # 1
            pooler_type=pooler_type, # ROIAlignV2
            canonical_level=2)
        pdb.set_trace()
    '''
    ROIPooler(
            (level_poolers): ModuleList(
                (0): ROIAlign(output_size=(56, 56), spatial_scale=0.25, sampling_ratio=1, aligned=True)
            )
        )

    '''
    def __call__(self, bases, proposals, gt_instances):
        if gt_instances is not None:
            # training
            # reshape attns
            dense_info = proposals["instances"] # 254 positive instances in this run
            attns = dense_info.top_feats # [num_pos, 784]
            pos_inds = dense_info.pos_inds # [num_pos]; indices of the positive-sample locations, counted over the pixels of all FPN levels concatenated
            if pos_inds.numel() == 0:
                return None, {"loss_mask": sum([x.sum() * 0 for x in attns]) + bases[0].sum() * 0}

            gt_inds = dense_info.gt_inds # [254]; for each positive location in pos_inds, the index of its matched GT instance, see the notes below
            # len(gt_instances) = 2, i.e. one Instances object per image in the batch (here gt_instances[0] has 4 instances, gt_instances[1] has 1)

            # 1. corresponds to Eq. (1) in the paper: r_d = RoIPool_{R x R}(B, p_d), d = 1..D
            rois = self.pooler(bases, [x.gt_boxes for x in gt_instances]) # ROIPooler.forward(), see the detectron2 source; torch.Size([num_gt_instances, 4, 56, 56]), here [5, 4, 56, 56]
            rois = rois[gt_inds] # torch.Size([49, 4, 56, 56]); each value in gt_inds indexes a row of rois, so every positive location gets a copy of its GT's base crop (a toy indexing example follows the pdb dumps below)
            pdb.set_trace()
            pred_mask_logits = self.merge_bases(rois, attns) # [49, 56*56]

            # gen targets
            gt_masks = []
            for instances_per_image in gt_instances:
                if len(instances_per_image.gt_boxes.tensor) == 0:
                    continue
                gt_mask_per_image = instances_per_image.gt_masks.crop_and_resize( # crop and resize the GT masks to 56 x 56
                    instances_per_image.gt_boxes.tensor, self.pooler_resolution
                ).to(device=pred_mask_logits.device) # gt_mask_per_image.size() --> [num_gt_in_image, 56, 56], boolean
                gt_masks.append(gt_mask_per_image)
            gt_masks = cat(gt_masks, dim=0) # [5, 56, 56]
            gt_masks = gt_masks[gt_inds] # [49, 56, 56]
            N = gt_masks.size(0) # 49
            gt_masks = gt_masks.view(N, -1) # [49, 3136]

            gt_ctr = dense_info.gt_ctrs # [49]
            loss_denorm = proposals["loss_denorm"]  # loss_denorm: ctrness_targets.sum()
            # mask BCE loss
            mask_losses = F.binary_cross_entropy_with_logits(  # [49, 3136]
                pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="none") # why isn't reduction 'sum' here? the per-pixel losses are kept so they can first be averaged per instance
            mask_loss = ((mask_losses.mean(dim=-1) * gt_ctr).sum() # the mask loss is also weighted by the GT centerness and normalized by loss_denorm (a sketch follows the pdb dumps below)
                         / loss_denorm)
            pdb.set_trace()
            return None, {"loss_mask": mask_loss}
        else:
            # no proposals
            total_instances = sum([len(x) for x in proposals])
            if total_instances == 0:
                # add empty pred_masks results
                for box in proposals:
                    box.pred_masks = box.pred_classes.view(
                        -1, 1, self.pooler_resolution, self.pooler_resolution)
                return proposals, {}
            rois = self.pooler(bases, [x.pred_boxes for x in proposals])
            attns = cat([x.top_feat for x in proposals], dim=0)
            pred_mask_logits = self.merge_bases(rois, attns).sigmoid()
            pred_mask_logits = pred_mask_logits.view(
                -1, 1, self.pooler_resolution, self.pooler_resolution)
            start_ind = 0
            for box in proposals:
                end_ind = start_ind + len(box)
                box.pred_masks = pred_mask_logits[start_ind:end_ind]
                start_ind = end_ind
            pdb.set_trace()
            return proposals, {}

    def merge_bases(self, rois, coeffs, location_to_inds=None):
        # merge predictions
        # coeffs: [N, 784], rois: [N, 4, 56, 56]; N here is the number of positive locations (len(gt_inds))
        N = coeffs.size(0)
        pdb.set_trace()
        if location_to_inds is not None: # NONE
            rois = rois[location_to_inds]
        N, B, H, W = rois.size()

        coeffs = coeffs.view(N, -1, self.attn_size, self.attn_size) # [N, -1, M, M] --> [N, 4, 14, 14])
        
        # 2. corresponds to Eqs. (2) and (3) in the paper: a'_d = interpolate_{M x M -> R x R}(a_d), s_d = softmax(a'_d)
        coeffs = F.interpolate(coeffs, (H, W),
                               mode=self.top_interp).softmax(dim=1) # s_d = softmax(a'_d): softmax over the basis (channel) dimension, i.e. over the 4 bases here  # [N, 4, 14, 14] --> [N, 4, 56, 56]

        # 3. corresponds to Eq. (4) in the paper: m_d = sum_k s_d^k * r_d^k
        masks_preds = (rois * coeffs).sum(dim=1)  # torch.Size([N, 56, 56])

        pdb.set_trace()
        return masks_preds.view(N, -1) # [N, 56 * 56]
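To make Eqs. (1)-(4) concrete, here is a standalone toy version of the blending step above. It only mirrors the shapes (K=4 bases, attn_size M=14, pooler_resolution R=56 as in the config comments) and uses random tensors:

import torch
import torch.nn.functional as F

N, K, M, R = 3, 4, 14, 56
rois = torch.randn(N, K, R, R)       # Eq. (1): r_d, the K base crops per positive location
attns = torch.randn(N, K * M * M)    # a_d, the per-location attention predicted by top_layer

coeffs = attns.view(N, K, M, M)
coeffs = F.interpolate(coeffs, (R, R), mode="bilinear", align_corners=False)  # Eq. (2): a'_d
coeffs = coeffs.softmax(dim=1)       # Eq. (3): s_d, softmax over the K bases
masks = (rois * coeffs).sum(dim=1)   # Eq. (4): m_d = sum_k s_d^k * r_d^k
print(masks.shape)                   # torch.Size([3, 56, 56])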


'''
(Pdb) dense_info.top_feats.size()
torch.Size([254, 784])
(Pdb) dense_info.pos_inds
tensor([ 8690,  8691,  8692,  8838,  8839,  8840,  8986,  8987,  8988, 10978,
        10979, 10980, 11126, 11127, 11128, 11274, 11275, 11276, 14219, 14220,
        14221, 14367, 14368, 14369, 14515, 14516, 14517, 14522, 14523, 14652,
        14653, 14670, 14671, 14800, 14801, 14818, 14819, 14948, 14949, 15103,
        15251, 15399, 15703, 15704, 15705, 15851, 15852, 15853, 17181, 17182,
        17183, 17329, 17330, 17331, 17477, 17478, 17479, 17785, 17786, 17933,
        17934, 17936, 17937, 18081, 18082, 18084, 18085, 18232, 18233, 18950,
        18951, 18952, 19098, 19099, 19100, 19246, 19247, 19248, 19703, 19705,
        19844, 19845, 19846, 19851, 19853, 19995, 19999, 20001, 20143, 20291,
        21157, 21158, 21159, 21305, 21306, 21307, 21453, 21454, 21455, 21988,
        21989, 21990, 22136, 22137, 22138, 22284, 22285, 22286, 22315, 22316,
        22317, 22463, 22464, 22465, 22611, 22612, 22613, 23019, 23020, 23021,
        30533, 30534, 30535, 31092, 31093, 31094, 32271, 32272, 32273, 32345,
        32347, 32419, 32420, 32421, 32640, 32641, 32642, 33031, 33032, 33033,
        33105, 33106, 33107, 33178, 33179, 33180, 33181, 33252, 33253, 33254,
        33305, 33306, 33307, 33379, 33380, 33381, 33453, 33454, 33455, 33666,
        33667, 33712, 33713, 33714, 33786, 33787, 33788, 33796, 33797, 33798,
        33860, 33861, 33862, 33870, 33871, 33872, 33944, 33945, 33946, 36023,
        36024, 36025, 36097, 36098, 36099, 36151, 36152, 36153, 36188, 36189,
        36190, 36225, 36226, 36227, 36493, 36494, 36495, 36530, 36531, 36532,
        36567, 36568, 36569, 36679, 36680, 36681, 36753, 36754, 36755, 36871,
        36872, 36873, 36908, 36909, 36910, 36945, 36946, 36947, 37029, 37030,
        37031, 37066, 37067, 37068, 37103, 37104, 37105, 37379, 37380, 37381,
        37398, 37399, 37400, 37417, 37418, 37419, 37644, 37645, 37646, 37659,
        37660, 37661, 37663, 37664, 37665, 37678, 37679, 37680, 37682, 37683,
        37684, 37697, 37698, 37699], device='cuda:0')
(Pdb) dense_info.pos_inds.size()
torch.Size([254])
(Pdb)  dense_info.gt_inds
tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        23, 23, 23, 23, 23, 23, 23, 23, 23, 27, 27, 25, 25, 27, 27, 25, 25, 27,
        27, 25, 25, 26, 26, 26, 15, 15, 15, 15, 15, 15, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 19, 19, 19, 19, 18, 18, 19, 19, 18, 18, 18, 18, 20, 20, 20,
        20, 20, 20, 20, 20, 20, 30, 17, 16, 16, 16, 30, 17, 24, 30, 17, 24, 24,
        22, 22, 22, 22, 22, 22, 22, 22, 22, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        29, 29, 29, 29, 29, 29, 29, 29, 29, 13, 13, 13,  2,  2,  2,  3,  3,  3,
        15, 15, 15, 15, 15, 15, 15, 15, 14, 14, 14,  6,  6,  6,  6,  6,  6, 11,
         6,  6,  6, 11, 11, 11, 16, 16, 16, 16, 16, 16, 16, 16, 16, 22, 22,  8,
         8,  8,  8,  8,  8,  9,  9,  9,  8,  8,  8,  9,  9,  9,  9,  9,  9,  2,
         2,  2,  2,  2,  2,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,
         5,  5,  5,  5,  5, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 28, 28, 28, 28, 28, 28, 28, 28, 28,  1,  1,  1,  1,  1,  1,  1,
         1,  1, 10, 10, 10,  7,  7,  7, 10, 10, 10,  7,  7,  7, 10, 10, 10,  7,
         7,  7], device='cuda:0')
(Pdb)  dense_info.gt_inds.size()
torch.Size([254])

'''
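A toy illustration of the rois = rois[gt_inds] indexing used in __call__ above (my own minimal example, not AdelaiDet code): rois holds one pooled base crop per GT box, and indexing with gt_inds repeats each crop once for every positive location matched to that GT.

import torch

rois = torch.randn(5, 4, 56, 56)               # one 4-channel base crop per GT box (5 GT boxes in the batch)
gt_inds = torch.tensor([0, 0, 0, 3, 3, 1, 4])  # toy mapping: positive location -> GT instance index
rois_per_pos = rois[gt_inds]                   # each positive gets a copy of its GT's crop
print(rois_per_pos.shape)                      # torch.Size([7, 4, 56, 56])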

'''
(Pdb) gt_inds
tensor([1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2,
        2], device='cuda:0')
(Pdb) proposals['instances'][40].gt_inds
tensor([4], device='cuda:0')
(Pdb) proposals['instances'][40].labels
tensor([23], device='cuda:0')
(Pdb) proposals['instances'][41].labels
tensor([23], device='cuda:0')
(Pdb) proposals['instances'][4].labels
tensor([0], device='cuda:0')
(Pdb) proposals['instances'][42].labels
tensor([23], device='cuda:0')
(Pdb) proposals['instances'][43].labels
tensor([23], device='cuda:0')
(Pdb) proposals['instances'][44].labels
tensor([23], device='cuda:0')
(Pdb) proposals['instances'][45].labels
tensor([7], device='cuda:0')
(Pdb)  rois[gt_inds].size()
torch.Size([49, 4, 56, 56])
(Pdb)  self.pooler(bases, [x.gt_boxes for x in gt_instances]) .size()
torch.Size([5, 4, 56, 56])
(Pdb) gt_instances
[Instances(num_instances=4, image_height=981, image_width=736, fields=[gt_boxes: Boxes(tensor([[  9.7827, 800.9252,  34.3773, 858.4823],
        [ 71.1773, 800.0361,  84.4100, 833.3441],
        [207.5367,  37.8758, 531.0547, 879.5432],
        [438.8247, 763.7545, 633.2820, 933.7280]], device='cuda:0')), gt_classes: tensor([0, 0, 7, 7], device='cuda:0'), gt_masks: PolygonMasks(num_instances=4)]), Instances(num_instances=1, image_height=939, image_width=704, fields=[gt_boxes: Boxes(tensor([[ 50.1013, 113.9418, 678.6853, 827.1710]], device='cuda:0')), gt_classes: tensor([23], device='cuda:0'), gt_masks: PolygonMasks(num_instances=1)])]
(Pdb) rois.size()
torch.Size([49, 4, 56, 56])
(Pdb) self.pooler(bases, [x.gt_boxes for x in gt_instances]).size()
torch.Size([5, 4, 56, 56])

'''
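And a minimal standalone sketch of the mask-loss normalization in __call__ above, using toy tensors under the shapes from the trace (49 positives, 56x56 masks): per-pixel BCE is averaged per instance, weighted by the GT centerness of each positive location, and divided by loss_denorm (documented above as ctrness_targets.sum()).

import torch
import torch.nn.functional as F

num_pos, R = 49, 56
pred_mask_logits = torch.randn(num_pos, R * R)
gt_masks = torch.randint(0, 2, (num_pos, R * R)).float()
gt_ctr = torch.rand(num_pos)        # GT centerness of each positive location
loss_denorm = gt_ctr.sum()          # in the real code: proposals["loss_denorm"]

per_pixel = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="none")
mask_loss = (per_pixel.mean(dim=-1) * gt_ctr).sum() / loss_denorm
print(mask_loss)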

3. AdelaiDet/adet/modeling/blendmask/basis_module.py

from typing import Dict
from torch import nn
from torch.nn import functional as F

from detectron2.utils.registry import Registry
from detectron2.layers import ShapeSpec

from adet.layers import conv_with_kaiming_uniform


BASIS_MODULE_REGISTRY = Registry("BASIS_MODULE")
BASIS_MODULE_REGISTRY.__doc__ = """
Registry for basis module, which produces global bases from feature maps.

The registered object will be called with `obj(cfg, input_shape)`.
The call should return a `nn.Module` object.
"""

import pdb
def build_basis_module(cfg, input_shape):
    name = cfg.MODEL.BASIS_MODULE.NAME # ProtoNet
    return BASIS_MODULE_REGISTRY.get(name)(cfg, input_shape)


@BASIS_MODULE_REGISTRY.register()
class ProtoNet(nn.Module):
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):

        # input_shape:
        # {
        #   'p3': ShapeSpec(channels=256, height=None, width=None, stride=8), 
        #   'p4': ShapeSpec(channels=256, height=None, width=None, stride=16), 
        #   'p5': ShapeSpec(channels=256, height=None, width=None, stride=32), 
        #   'p6': ShapeSpec(channels=256, height=None, width=None, stride=64), 
        #   'p7': ShapeSpec(channels=256, height=None, width=None, stride=128)
        # }

        """
        TODO: support deconv and variable channel width
        """
        # official protonet has a relu after each conv
        super().__init__()
        # fmt: off
        mask_dim          = cfg.MODEL.BASIS_MODULE.NUM_BASES # 4
        planes            = cfg.MODEL.BASIS_MODULE.CONVS_DIM # 128
        self.in_features  = cfg.MODEL.BASIS_MODULE.IN_FEATURES # ["p3", "p4", "p5"]
        self.loss_on      = cfg.MODEL.BASIS_MODULE.LOSS_ON # True
        norm              = cfg.MODEL.BASIS_MODULE.NORM # SyncBN
        num_convs         = cfg.MODEL.BASIS_MODULE.NUM_CONVS # 3
        self.visualize    = cfg.MODEL.BLENDMASK.VISUALIZE
        # fmt: on

        feature_channels = {k: v.channels for k, v in input_shape.items()} # {'p3': 256, 'p4': 256, 'p5': 256, 'p6': 256, 'p7': 256}


        conv_block = conv_with_kaiming_uniform(norm, True)  # conv + norm + relu
        self.refine = nn.ModuleList()
        for in_feature in self.in_features:
            self.refine.append(conv_block(
                feature_channels[in_feature], planes, 3, 1))
        tower = []
        for i in range(num_convs):
            tower.append(
                conv_block(planes, planes, 3, 1))
        tower.append(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
        tower.append(
            conv_block(planes, planes, 3, 1))
        tower.append(
            nn.Conv2d(planes, mask_dim, 1))
        self.add_module('tower', nn.Sequential(*tower))

        if self.loss_on:
            # fmt: off
            self.common_stride   = cfg.MODEL.BASIS_MODULE.COMMON_STRIDE  # 8
            num_classes          = cfg.MODEL.BASIS_MODULE.NUM_CLASSES + 1  # 81
            self.sem_loss_weight = cfg.MODEL.BASIS_MODULE.LOSS_WEIGHT # 0.3
            # fmt: on

            inplanes = feature_channels[self.in_features[0]] # 256
            self.seg_head = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=3,
                                                    stride=1, padding=1, bias=False),
                                          nn.BatchNorm2d(planes),
                                          nn.ReLU(),
                                          nn.Conv2d(planes, planes, kernel_size=3,
                                                    stride=1, padding=1, bias=False),
                                          nn.BatchNorm2d(planes),
                                          nn.ReLU(),
                                          nn.Conv2d(planes, num_classes, kernel_size=1,
                                                    stride=1))
        pdb.set_trace()

        '''
        tower
            [
                Sequential(
                    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): ReLU(inplace=True)
                ),
                Sequential(
                    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): ReLU(inplace=True)
                ), 
                Sequential(
                    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): ReLU(inplace=True)
                ), 
                Upsample(scale_factor=2.0, mode=bilinear), 
                Sequential(
                    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): ReLU(inplace=True)
                ), 
                Conv2d(128, 4, kernel_size=(1, 1), stride=(1, 1))
            ]

        '''

        '''
        seg_head
            Sequential(
                    (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): ReLU()
                    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (5): ReLU()
                    (6): Conv2d(128, 81, kernel_size=(1, 1), stride=(1, 1))
                )

        '''


        '''
        self.refine
            ModuleList(
            (0): Sequential(
                (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
            )
            (1): Sequential(
                (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
            )
            (2): Sequential(
                (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): NaiveSyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
            )
            )

        '''
    def forward(self, features, targets=None): # len(features) = 5; targets: [N, image_h, image_w], the GT semantic segmentation used to supervise the auxiliary loss
        for i, f in enumerate(self.in_features): # self.in_features: ['p3', 'p4', 'p5']
            if i == 0:
                x = self.refine[i](features[f])
            else:
                x_p = self.refine[i](features[f])
                x_p = F.interpolate(x_p, x.size()[2:], mode="bilinear", align_corners=False)
                # x_p = aligned_bilinear(x_p, x.size(3) // x_p.size(3))
                x = x + x_p
                pdb.set_trace() # x [2, 128, 96, 148]
        outputs = {"bases": [self.tower(x)]} # outputs.keys() --> dict_keys(['bases']); outputs['bases'][0].size() --> [2, 4, 192, 296] (the tower includes a 2x upsample)
        losses = {}
        # auxiliary thing semantic loss
        if self.training and self.loss_on: # True
            sem_out = self.seg_head(features[self.in_features[0]]) # features['p3'] sem_out.size() --> [2, 81, 160, 96]
            # resize target to reduce memory
            gt_sem = targets.unsqueeze(1).float() # gt_sem :[2, 1, 1280, 768]
            gt_sem = F.interpolate(
                gt_sem, scale_factor=1 / self.common_stride) # self.common_stride = 8; after downscaling: gt_sem.size() --> [2, 1, 160, 96]
            
            seg_loss = F.cross_entropy(
                sem_out, gt_sem.squeeze(1).long()) # [2, 96, 148]
            losses['loss_basis_sem'] = seg_loss * self.sem_loss_weight # the auxiliary semantic loss; self.sem_loss_weight = 0.3
        elif self.visualize and hasattr(self, "seg_head"):
            outputs["seg_thing_out"] = self.seg_head(features[self.in_features[0]])
        pdb.set_trace()
        return outputs, losses
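A condensed, standalone sketch of what this forward pass does to P3-P5 (my own re-implementation of the pattern with the channel sizes from the config comments, not the registered ProtoNet; the tower is shortened to one conv for brevity):

import torch
import torch.nn as nn
import torch.nn.functional as F

refine = nn.ModuleList([nn.Conv2d(256, 128, 3, padding=1) for _ in range(3)])  # one branch per P3/P4/P5
tower = nn.Sequential(nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(),
                      nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                      nn.Conv2d(128, 4, 1))                                    # K=4 bases

features = {"p3": torch.randn(2, 256, 96, 148),
            "p4": torch.randn(2, 256, 48, 74),
            "p5": torch.randn(2, 256, 24, 37)}

x = refine[0](features["p3"])                       # the highest-resolution level sets the target size
for i, f in enumerate(["p4", "p5"], start=1):
    x_p = refine[i](features[f])
    x = x + F.interpolate(x_p, x.size()[2:], mode="bilinear", align_corners=False)  # upsample and add
bases = tower(x)                                    # [2, 4, 192, 296] after the 2x upsample
print(bases.shape)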

Reposted from blog.csdn.net/weixin_43823854/article/details/109967482