Original paper: https://arxiv.org/abs/2006.11392
Source code: https://github.com/DengPingFan/PraNet
Let's dive straight into the code~~~
1. Defining the basic convolution module
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)  # defined but never used in forward(); callers apply their own activation

    def forward(self, x):
        # Conv -> BN only; note there is deliberately no ReLU here
        x = self.conv(x)
        x = self.bn(x)
        return x
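A quick sanity check (a minimal sketch, not from the repo): with padding matched to the kernel size, the block changes only the channel count and keeps the spatial size.

check_x = torch.randn(1, 64, 88, 88)                  # dummy feature map
check_conv = BasicConv2d(64, 32, kernel_size=3, padding=1)
print(check_conv(check_x).shape)                      # torch.Size([1, 32, 88, 88])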
2. Defining the RFB module
class RFB(nn.Module):
    # RFB-like multi-scale module
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        # branch0: plain 1x1 conv
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        # branch1/2/3: 1x1 conv -> factorized kxk conv -> dilated 3x3 conv;
        # padding equals dilation, so every branch preserves the spatial size
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        # concatenate the four branches, fuse, and add a 1x1 residual of the input
        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
        x = self.relu(x_cat + self.conv_res(x))
        return x
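Because padding equals dilation in the final conv of each branch, the four outputs all share the input's spatial size and can be concatenated along the channel dimension. A minimal shape check (sizes chosen to match the layer2 feature map used later):

check_rfb = RFB(512, 32)
check_x = torch.randn(1, 512, 44, 44)
print(check_rfb(check_x).shape)                       # torch.Size([1, 32, 44, 44])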
3. Defining the aggregation module
class aggregation(nn.Module):
    # dense aggregation; it could be replaced by other aggregation schemes such as DSS, Amulet, etc.
    # used after the multi-scale RFB features
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv4 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3 * channel, 1, 1)

    def forward(self, x1, x2, x3):
        # x1 is the coarsest input, x3 the finest; adjacent resolutions differ by a factor of 2
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
               * self.conv_upsample3(self.upsample(x2)) * x3
        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)
        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)
        x = self.conv4(x3_2)
        x = self.conv5(x)   # 3*channel -> 1: the coarse global saliency map
        return x
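The three inputs must be ordered coarse to fine, since each upsample step doubles H and W. A shape check with the resolutions used in the full network below (a sketch, not from the repo):

check_agg = aggregation(32)
f1 = torch.randn(1, 32, 11, 11)                       # coarsest RFB output (from layer4)
f2 = torch.randn(1, 32, 22, 22)                       # from layer3
f3 = torch.randn(1, 32, 44, 44)                       # finest RFB output (from layer2)
print(check_agg(f1, f2, f3).shape)                    # torch.Size([1, 1, 44, 44])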
4. The overall network structure
class CRANet(nn.Module):
    # ResNet-based encoder-decoder
    def __init__(self, channel=32):
        super(CRANet, self).__init__()
        # ---- ResNet backbone ----
        self.resnet = ResNet()
        # ---- Receptive Field Blocks ----
        self.rfb2_1 = RFB(512, channel)
        self.rfb3_1 = RFB(1024, channel)
        self.rfb4_1 = RFB(2048, channel)
        # ---- Partial Decoder ----
        self.agg1 = aggregation(channel)
        # ---- reverse attention branch 4 ----
        self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1)
        self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1)
        # ---- reverse attention branch 3 ----
        self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1)
        self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
        # ---- reverse attention branch 2 ----
        self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1)
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
        if self.training:
            self.initialize_weights()   # loads the pretrained backbone weights

    def forward(self, x):                       # input: (bs, 3, 352, 352)
        x = self.resnet.conv1(x)                # bs, 64, 176, 176
        x = self.resnet.bn1(x)                  # bs, 64, 176, 176
        x = self.resnet.relu(x)                 # bs, 64, 176, 176
        x = self.resnet.maxpool(x)              # bs, 64, 88, 88
        x1 = self.resnet.layer1(x)              # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)             # bs, 512, 44, 44
        x3 = self.resnet.layer3(x2)             # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)             # bs, 2048, 11, 11
        x2_rfb = self.rfb2_1(x2)                # 512 -> 32 channels
        x3_rfb = self.rfb3_1(x3)                # 1024 -> 32 channels
        x4_rfb = self.rfb4_1(x4)                # 2048 -> 32 channels
        # inputs at 11/22/44 resolution; output: (bs, 1, 44, 44)
        ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb)
        lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear')    # Sup-1: (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_4 ----
        crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear')        # bs, 1, 11, 11
        x = -1 * (torch.sigmoid(crop_4)) + 1    # reverse attention weights: 1 - sigmoid(prediction)
        x = x.expand(-1, 2048, -1, -1).mul(x4)  # bs, 2048, 11, 11
        x = self.ra4_conv1(x)                   # bs, 256, 11, 11
        x = F.relu(self.ra4_conv2(x))           # bs, 256, 11, 11
        x = F.relu(self.ra4_conv3(x))           # bs, 256, 11, 11
        x = F.relu(self.ra4_conv4(x))           # bs, 256, 11, 11
        ra4_feat = self.ra4_conv5(x)            # bs, 1, 11, 11
        x = ra4_feat + crop_4                   # residual refinement of the coarse map
        lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear')          # Sup-2: (bs, 1, 11, 11) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_3 ----
        crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear')                  # bs, 1, 22, 22
        x = -1 * (torch.sigmoid(crop_3)) + 1    # bs, 1, 22, 22
        x = x.expand(-1, 1024, -1, -1).mul(x3)  # bs, 1024, 22, 22
        x = self.ra3_conv1(x)                   # bs, 64, 22, 22
        x = F.relu(self.ra3_conv2(x))           # bs, 64, 22, 22
        x = F.relu(self.ra3_conv3(x))           # bs, 64, 22, 22
        ra3_feat = self.ra3_conv4(x)            # bs, 1, 22, 22
        x = ra3_feat + crop_3                   # bs, 1, 22, 22
        lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear')          # Sup-3: (bs, 1, 22, 22) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_2 ----
        crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear')                  # bs, 1, 44, 44
        x = -1 * (torch.sigmoid(crop_2)) + 1    # bs, 1, 44, 44
        x = x.expand(-1, 512, -1, -1).mul(x2)   # bs, 512, 44, 44
        x = self.ra2_conv1(x)                   # bs, 64, 44, 44
        x = F.relu(self.ra2_conv2(x))           # bs, 64, 44, 44
        x = F.relu(self.ra2_conv3(x))           # bs, 64, 44, 44
        ra2_feat = self.ra2_conv4(x)            # bs, 1, 44, 44
        x = ra2_feat + crop_2                   # bs, 1, 44, 44
        lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear')           # Sup-4: (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2
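All four lateral maps are logits at the full input resolution, one per supervision level. The training script in the official repo supervises each of them against the ground-truth mask with a weighted BCE + weighted IoU loss ("structure loss") and sums the four terms. Below is a sketch of that usage; structure_loss follows the repo's training code, while the stand-in tensors replace a real model forward (which would need the ResNet backbone from the repo):

def structure_loss(pred, mask):
    # boundary-aware weighting: pixels near the mask edge count more
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()

# model = CRANet()                                   # needs the ResNet backbone from the repo
# map5, map4, map3, map2 = model(images)
gts = torch.randint(0, 2, (2, 1, 352, 352)).float()  # dummy binary ground-truth masks
map5 = map4 = map3 = map2 = torch.randn(2, 1, 352, 352)   # stand-ins for the four outputs
loss = (structure_loss(map5, gts) + structure_loss(map4, gts)
        + structure_loss(map3, gts) + structure_loss(map2, gts))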