1、定义自己的head
driving\models\dense_heads\shuai_head.py
import torch
from torch import nn
from collections import namedtuple
from mmengine.model import BaseModule
from mmseg.models import HEADS, build_head, build_loss
import sys
sys.path.append("D:/BaiduSyncdisk/SHUAI/")
from models.losses.shuai_loss import * # 注册 ShuaiLoss
@HEADS.register_module()
class ShuaiHead(BaseModule):
    """Classification head: 1x1 conv -> BN -> ReLU6 -> global avg pool -> dropout -> FC.

    Args:
        loss: Config handed to ``build_loss`` to construct the training loss.
        task: Task name stored on the head (default ``'RoadCls'``).
        meta_info: Optional metadata, stored on the head unchanged.
        aux_annotation: Accepted but not used by this head.
        source: Accepted but not used by this head.
        num_classes: Number of logits produced by the final linear layer.
        momentum: ``momentum`` of the ``BatchNorm2d`` layer.
        epsilon: ``eps`` of the ``BatchNorm2d`` layer.
        in_channels: Number of channels of the incoming feature map.
        out_channels: Channels produced by the 1x1 conv and consumed by the FC.
        dropout_rate: Dropout probability applied before the FC layer;
            a value ``<= 0`` disables dropout.
    """

    # Built once instead of on every forward call; the type name and field
    # name are the same as before, so ``result.roadcls_pred`` keeps working.
    _Outputs = namedtuple("outputs", "roadcls_pred")

    def __init__(
        self,
        loss,
        task='RoadCls',
        meta_info=None,
        aux_annotation=None,
        source=None,
        num_classes=4,
        momentum=0.01,
        epsilon=1e-3,
        in_channels=448,
        out_channels=1280,
        dropout_rate=0.3,
    ):
        super().__init__()
        print(" ShuaiHead __init__")
        # No trailing commas here: ``self.task = task,`` would store a 1-tuple.
        self.task = task
        self.loss = build_loss(loss)
        self.meta_info = meta_info
        self.roadcls_head = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(num_features=out_channels, momentum=momentum,
                           eps=epsilon),
            nn.ReLU6(inplace=True),
        )
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
        self.fc = torch.nn.Linear(out_channels, num_classes)

    def forward(self, x):
        """Compute class logits for a feature map.

        Args:
            x: Feature map of shape ``[B, in_channels, H, W]``
                (e.g. ``[1, 448, 7, 7]`` with the defaults).

        Returns:
            A named tuple with a single field ``roadcls_pred`` holding the
            logits of shape ``[B, num_classes]``.
        """
        x = self.roadcls_head(x)   # [B, out_channels, H, W]
        x = self.avgpool(x)        # [B, out_channels, 1, 1]
        x = torch.flatten(x, 1)    # [B, out_channels]
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc(x)             # [B, num_classes]
        pred_dict = self._Outputs(roadcls_pred=x)
        print("ShuaiHead forward:", pred_dict)
        return pred_dict

    def forward_train(self, head_args):
        """Run the forward pass and compute the training loss.

        Args:
            head_args: Mapping whose values are, in insertion order,
                ``(input, img_metas, annotations, train_cfg)``.
                ``annotations`` is itself a mapping whose values are, in
                order, ``(target, is_annotation_present, sample_weight)``.

        Returns:
            Whatever the configured loss returns for the predicted logits
            and the target.
        """
        # NOTE: unpacking is positional, so both mappings must preserve the
        # value order documented above; only the input and target are used.
        feats, _img_metas, annotations, _train_cfg = head_args.values()
        target, _is_annotation_present, _sample_weight = annotations.values()

        pred_dict = self.forward(feats)
        train_loss = self.loss(
            pred_dict.roadcls_pred,
            target,
            device_id=feats.device,
            sample_ratio=1.0,
        )
        print("ShuaiHead forward_train:", train_loss)
        return train_loss
看下 HEADS 注册表(`@HEADS.register_module()`):
- 可以看到 ShuaiHead 可以被注册到 HEADS 中
- 其实,这里的 HEADS 是 BACKBONES、NECKS、HEADS、LOSSES、SEGMENTORS 的总和(它们都指向同一个注册表 MODELS):
# Every name below is bound to the same MODELS registry object, so a class
# registered through any one of them (e.g. HEADS) is visible through all.
from mmseg.registry import MODELS
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS
- 看下这里的 BaseModule:mmengine\model\base_module.py
2、调用Shuai_head
if __name__ == "__main__":
    # Demo: build ShuaiHead from the registry, then run a forward pass and a
    # forward pass with loss computation on random data.
    print("call shuai_head:")

    # 1. Config dicts for the loss and the head.
    num_classes = 4
    shuai_loss = dict(type='ShuaiLoss', loss_weight=1.0, loss_name='loss_shuai')
    head = dict(loss=shuai_loss, type='ShuaiHead', num_classes=num_classes)

    # 2. Build the head from the registry.
    shuai_head = build_head(head)

    # 3. Forward pass. torch.randn gives well-defined values, whereas
    #    torch.Tensor(2, 448, 7, 7) returns uninitialized memory that may
    #    contain NaN/Inf and makes the demo non-reproducible.
    feats = torch.randn(2, 448, 7, 7)  # [B, C, H, W]
    output = shuai_head(feats)

    # 4. Forward pass + loss computation.
    #    Target keeps the original shape/dtype ([B, num_classes], float32);
    #    NOTE(review): confirm the value range ShuaiLoss expects.
    target = torch.rand(2, num_classes)
    annotations = {
        "targets": target,
        "is_annotation_present": True,
        "sample_weight": 1,
    }
    # Key order matters: forward_train unpacks these mappings positionally.
    head_args = {
        "input": feats,
        "img_metas": None,
        "annotations": annotations,
        "train_cfg": None,
    }
    loss = shuai_head.forward_train(head_args)