Building the PraNet Segmentation Model

Original paper: https://arxiv.org/abs/2006.11392
Source code: https://github.com/DengPingFan/PraNet

Let's get straight to the point~~~

1. Define the Basic Convolution Module

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
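
Note that forward returns bn(conv(x)) without applying self.relu; this matches the original repo, where activations are applied externally (e.g., via F.relu in the main network). Below is a quick shape check, a minimal sketch with an example input size:

conv = BasicConv2d(3, 32, kernel_size=3, padding=1)
out = conv(torch.randn(2, 3, 352, 352))
print(out.shape)  # torch.Size([2, 32, 352, 352])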

2. Define the RFB Module

class RFB(nn.Module):
    # RFB-like multi-scale module
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))

        x = self.relu(x_cat + self.conv_res(x))
        return x
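
Each branch factorizes a kxk convolution into 1xk and kx1 convolutions and then applies a 3x3 dilated convolution whose padding equals its dilation, so every branch preserves the spatial size and the four outputs can be concatenated along the channel dimension. A minimal shape-check sketch (the 44x44 input size is just an example):

rfb = RFB(512, 32)
feat = rfb(torch.randn(2, 512, 44, 44))
print(feat.shape)  # torch.Size([2, 32, 44, 44])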

3. Define the Aggregation Module

class aggregation(nn.Module):
    # dense aggregation; it can be replaced by other aggregation modules, such as DSS, Amulet, and so on.
    # used after MSF
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)

        self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3*channel, 1, 1)

    def forward(self, x1, x2, x3):
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) * self.conv_upsample3(self.upsample(x2)) * x3

        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        x = self.conv5(x)
        return x
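
The aggregation module fuses three same-channel feature maps from coarse to fine: deeper features are upsampled, used to gate shallower ones by element-wise multiplication, then concatenated and reduced to a single-channel coarse map at the finest of the three resolutions. A minimal shape walkthrough, assuming a 352x352 network input (so the three levels are 11/22/44):

agg = aggregation(32)
f4 = torch.randn(2, 32, 11, 11)  # deepest level (stride 32)
f3 = torch.randn(2, 32, 22, 22)  # stride 16
f2 = torch.randn(2, 32, 44, 44)  # stride 8
print(agg(f4, f3, f2).shape)     # torch.Size([2, 1, 44, 44])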

4. Overall Network Structure

class CRANet(nn.Module):
    # ResNet-based encoder-decoder
    def __init__(self, channel=32):
        super(CRANet, self).__init__()
        # ---- ResNet-50-style backbone (assumed defined/imported elsewhere, as in the PraNet repo) ----
        self.resnet = ResNet()
        # Receptive Field Block
        self.rfb2_1 = RFB(512, channel)
        self.rfb3_1 = RFB(1024, channel)
        self.rfb4_1 = RFB(2048, channel)
        # Partial Decoder
        self.agg1 = aggregation(channel)
        # self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        # ---- reverse attention branch 4 ----
        self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1)
        self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1)
        # self.ra4_conv5_up = nn.ConvTranspose2d(1, 1, kernel_size=64, stride=32)
        # ---- reverse attention branch 3 ----
        # self.ra4_3 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2)
        self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1)
        self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
        # self.ra3_conv4_up = nn.ConvTranspose2d(1, 1, kernel_size=32, stride=16)
        # ---- reverse attention branch 2 ----
        # self.ra3_2 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2)
        self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1)
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
        # self.ra2_conv4_up = nn.ConvTranspose2d(1, 1, kernel_size=16, stride=8)

        # self.HA = HA()
        if self.training:
            self.initialize_weights()  # not shown here; assumed to load pretrained backbone weights as in the original repo
            # self.apply(CRANet.weights_init)

    def forward(self, x):               # input: bs, 3, 352, 352
        x = self.resnet.conv1(x)        # bs, 64, 176, 176
        x = self.resnet.bn1(x)          # bs, 64, 176, 176
        x = self.resnet.relu(x)         # bs, 64, 176, 176
        x = self.resnet.maxpool(x)      # bs, 64, 88, 88
        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44

        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11
        x2_rfb = self.rfb2_1(x2)        # channels: 512 -> 32
        x3_rfb = self.rfb3_1(x3)        # channels: 1024 -> 32
        x4_rfb = self.rfb4_1(x4)        # channels: 2048 -> 32

        ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb) # inputs at 11/22/44; output bs, 1, 44, 44
        lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear')    # Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_4 ----
        crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear') # bs, 1, 11, 11
        x = -1*(torch.sigmoid(crop_4)) + 1  # reverse attention weights: 1 - sigmoid(pred); bs, 1, 11, 11
        x = x.expand(-1, 2048, -1, -1).mul(x4) # bs, 2048, 11, 11
        x = self.ra4_conv1(x) # bs, 256, 11, 11
        x = F.relu(self.ra4_conv2(x)) # bs, 256, 11, 11
        x = F.relu(self.ra4_conv3(x)) # bs, 256, 11, 11
        x = F.relu(self.ra4_conv4(x)) # bs, 256, 11, 11
        ra4_feat = self.ra4_conv5(x) # bs, 1, 11, 11
        x = ra4_feat + crop_4  # bs, 1, 11, 11
        lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear')      # Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_3 ----
        # x = F.interpolate(x, scale_factor=2, mode='bilinear')
        crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear') # bs, 1, 22, 22
        x = -1*(torch.sigmoid(crop_3)) + 1 # bs, 1, 22, 22
        x = x.expand(-1, 1024, -1, -1).mul(x3)  # bs, 1024, 22, 22
        x = self.ra3_conv1(x) # bs, 64, 22, 22
        x = F.relu(self.ra3_conv2(x)) # bs, 64, 22, 22
        x = F.relu(self.ra3_conv3(x)) # bs, 64, 22, 22
        ra3_feat = self.ra3_conv4(x) # bs, 1, 22, 22
        x = ra3_feat + crop_3 # bs, 1, 22, 22
        lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear') # Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)
        # lateral_map_3 = self.crop(self.ra3_conv4_up(x), x_size)  # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_2 ----
        # x = self.ra3_2(x)
        # crop_2 = self.crop(x, x2.size())
        crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear') # bs, 1, 44, 44
        x = -1*(torch.sigmoid(crop_2)) + 1 # bs, 1, 44, 44
        x = x.expand(-1, 512, -1, -1).mul(x2) # bs, 512, 44, 44
        x = self.ra2_conv1(x) # bs, 64, 44, 44
        x = F.relu(self.ra2_conv2(x)) # bs, 64, 44, 44
        x = F.relu(self.ra2_conv3(x)) # bs, 64, 44, 44
        ra2_feat = self.ra2_conv4(x) # bs, 1, 44, 44
        x = ra2_feat + crop_2 # bs, 1, 44, 44
        lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear') # Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)
        # lateral_map_2 = self.crop(self.ra2_conv4_up(x), x_size)  # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2
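
The network returns four lateral maps, all upsampled to the input resolution, for deep supervision. In the official training script each map is supervised with a combination of weighted BCE and weighted IoU loss (the repo's structure_loss); the sketch below is a reconstruction based on the PraNet repo and may differ from it in details:

def structure_loss(pred, mask):
    # pixel-wise weights emphasizing boundary regions of the ground-truth mask
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()

# Total loss is the sum over the four lateral maps:
# loss = sum(structure_loss(m, gt) for m in model(images))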


Reposted from blog.csdn.net/m0_56247038/article/details/129705064