U2Net Source Code Analysis

[Figure: overall U2Net architecture, a U-shaped network assembled from ReSidual U-blocks]

As the figure above shows, U2Net is a U-shaped network assembled from ReSidual U-blocks (RSU), similar to the UNet Encoder-Decoder architecture; each RSU block is itself a small U-shaped network with a residual connection added. For now, don't worry about how an RSU is implemented internally. Let's first walk through the network following the data flow and the shape changes of X to see how the model operates; the internals are explained in detail afterwards.

Encoder

The input X (B, N, H, W) first passes through RSU7, which convolves the channels from N (3 by default) to 64, and is then downsampled to (B, 64, H/2, W/2). From there it flows through RSU6 (-> downSample) -> RSU5 (-> downSample) -> RSU4 (-> downSample) -> RSU4F (-> downSample) -> RSU4F, finally turning X into En_6 (B, 512, H/32, W/32).
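
As a quick illustration, here is a hypothetical trace of the encoder feature shapes for a 320x320 input (each RSU stage keeps H and W; each MaxPool2D halves them):

# En_i channel counts come from the stage definitions shown further below.
H = W = 320
channels = [64, 128, 256, 512, 512, 512]  # En_1 .. En_6
for i, c in enumerate(channels, start=1):
    scale = 2 ** (i - 1)
    print(f'En_{i}: (B, {c}, {H // scale}, {W // scale})')
# En_1: (B, 64, 320, 320)  ...  En_6: (B, 512, 10, 10)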

Decoder

En_6 is upsampled and concatenated with En_5, then passed through RSU4F to give De_5 (B, 512, H/16, W/16). Similarly,
De_5 is upsampled, concatenated with En_4, and passed through RSU4 to give De_4 (B, 256, H/8, W/8),
De_4 is upsampled, concatenated with En_3, and passed through RSU5 to give De_3 (B, 128, H/4, W/4),
De_3 is upsampled, concatenated with En_2, and passed through RSU6 to give De_2 (B, 64, H/2, W/2),
De_2 is upsampled, concatenated with En_1, and passed through RSU7 to give De_1 (B, 64, H, W).

Note that De_1 (B, 64, H, W) is not yet the final result. Borrowing the FPN idea, U2Net takes every decoder-stage output (De_1 - De_5, plus En_6), applies a Conv2d(any, num_classes), and upsamples it to (B, num_classes, H, W). These intermediate outputs are then fused and passed through one more Conv2d(num_classes * 6, num_classes) to produce the final result. In detail:

En_6 passes through Conv2d(512, num_classes) and is upsampled to Sup6 (B, num_classes, H, W),
De_5 passes through Conv2d(512, num_classes) and is upsampled to Sup5 (B, num_classes, H, W),
De_4 passes through Conv2d(256, num_classes) and is upsampled to Sup4 (B, num_classes, H, W),
De_3 passes through Conv2d(128, num_classes) and is upsampled to Sup3 (B, num_classes, H, W),
De_2 passes through Conv2d(64, num_classes) and is upsampled to Sup2 (B, num_classes, H, W),
De_1 passes through Conv2d(64, num_classes) to give Sup1 (B, num_classes, H, W).

Finally, (Sup1, Sup2, Sup3, Sup4, Sup5, Sup6) are concatenated and convolved by Conv2d(num_classes * 6, num_classes) into the output of shape (B, num_classes, H, W).
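
A minimal sketch of this fusion step, with arbitrary example shapes:

import paddle
import paddle.nn as nn

# Six side outputs, each (B, num_classes, H, W), concatenated along the
# channel axis and fused back to num_classes channels by a 1x1 convolution.
num_classes = 2
sups = [paddle.randn([1, num_classes, 64, 64]) for _ in range(6)]
fuse = nn.Conv2D(6 * num_classes, num_classes, 1)
out = fuse(paddle.concat(sups, axis=1))
print(out.shape)  # [1, 2, 64, 64]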

The source code corresponding to the flow above is shown next.
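
It relies on two things not shown in the snippet: utils.load_entire_model from PaddleSeg, and a small helper _upsample_like defined in the same file, which bilinearly resizes src to tar's spatial size. A minimal sketch of the imports and the helper, consistent with how they are called below (assumed, not copied verbatim from the repo):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg import utils  # assumed import path; provides load_entire_model


def _upsample_like(src, tar):
    # Bilinearly upsample src so its (H, W) matches tar's.
    return F.upsample(src, size=tar.shape[2:], mode='bilinear')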

class U2Net(nn.Layer):
    """
    The U^2-Net implementation based on PaddlePaddle.

    The original article refers to
    Xuebin Qin, et, al. "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection"
    (https://arxiv.org/abs/2005.09007).

    Args:
        num_classes (int): The unique number of target classes.
        in_ch (int, optional): Input channels. Default: 3.
        pretrained (str, optional): The path or url of pretrained model for fine tuning. Default: None.

    """

    def __init__(self, num_classes, in_ch=3, pretrained=None):
        super(U2Net, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2D(64, num_classes, 3, padding=1)
        self.side2 = nn.Conv2D(64, num_classes, 3, padding=1)
        self.side3 = nn.Conv2D(128, num_classes, 3, padding=1)
        self.side4 = nn.Conv2D(256, num_classes, 3, padding=1)
        self.side5 = nn.Conv2D(512, num_classes, 3, padding=1)
        self.side6 = nn.Conv2D(512, num_classes, 3, padding=1)

        self.outconv = nn.Conv2D(6 * num_classes, num_classes, 1)

        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(paddle.concat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(paddle.concat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(paddle.concat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(paddle.concat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(paddle.concat((hx2dup, hx1), 1))

        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(paddle.concat((d1, d2, d3, d4, d5, d6), 1))

        return [d0, d1, d2, d3, d4, d5, d6]

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
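
A quick sanity check of the shapes walked through above (hypothetical usage; an input whose H and W are multiples of 32 keeps every pooling step exact):

net = U2Net(num_classes=2)
x = paddle.randn([1, 3, 320, 320])
d0, d1, d2, d3, d4, d5, d6 = net(x)
# d0 is the fused output; d1..d6 correspond to Sup1..Sup6.
print(d0.shape)  # [1, 2, 320, 320]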

The following sections explain the details involved in the code above:

RSU

RSU4, RSU5, RSU6, and RSU7 are internally much alike, so explaining RSU4 suffices.

[Figure: RSU4 data flow and shape changes]

The figure above shows RSU4's data flow and shape changes. The REBNCONV operation is a convolution that keeps H and W unchanged, followed by BatchNorm2D and a ReLU activation:

class REBNCONV(nn.Layer):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2D(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2D(out_ch)
        self.relu_s1 = nn.ReLU()

    def forward(self, x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

A note on why the convolution keeps the spatial size unchanged:

When dilation is introduced in nn.Conv2D, the effective kernel size becomes:

$new\_kernel\_size = dilation \times (kernel\_size - 1) + 1$

$d_{out} = \frac{d_{in} - new\_kernel\_size + 2 \times padding}{stride} + 1$

As the source above shows, padding and dilation are both multiplied by dirate. In this network the stride is 1 and the kernel size is 3, so new_kernel_size = 2 * dirate + 1 and the output size reduces to:

$d_{out} = d_{in} - (2 \times dirate + 1) + 2 \times dirate + 1 = d_{in}$
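
A quick numeric check of this size-preservation property, using the REBNCONV class above (hypothetical usage):

import paddle

x = paddle.randn([1, 3, 64, 64])
for d in (1, 2, 4, 8):
    y = REBNCONV(3, 16, dirate=d)(x)
    print(d, y.shape)  # [1, 16, 64, 64] for every dirate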

A brief summary of the figure: internally, an RSU runs a UNet-like pass. The input goes through a series of size-preserving convolutions interleaved with MaxPool2D(stride=2) downsampling, with a dilated convolution at the bottleneck (REBNCONV(mid_ch, mid_ch, dirate=2) in RSU4). The decoder path then concatenates, convolves, and upsamples in reverse. Finally, the decoder result (B, out_ch, H, W) is added, as a residual connection, to the input X (B, in_ch, H, W) after a single REBNCONV has mapped it to (B, out_ch, H, W). The full code:

class RSU4(nn.Layer):  #UNet04DRES(nn.Layer):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2D(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(paddle.concat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(paddle.concat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(paddle.concat((hx2dup, hx1), 1))

        return hx1d + hxin
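
A hypothetical usage check: RSU4 maps (B, in_ch, H, W) to (B, out_ch, H, W), which is exactly what the residual add hx1d + hxin requires:

import paddle

block = RSU4(in_ch=256, mid_ch=128, out_ch=512)
y = block(paddle.randn([1, 256, 32, 32]))
print(y.shape)  # [1, 512, 32, 32], H and W preserved end to end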

RSU4F

[Figure: RSU4F data flow and shape changes]

The core idea of RSU4F: the input is convolved from in_ch channels to out_ch, then from out_ch down to mid_ch, and the mid_ch-to-mid_ch convolutions then use dilation rates 2, 4, and 8 in turn. The figure above shows an RSU4F with a (32, 32, 512) input and in_ch = 512, mid_ch = 256, out_ch = 512; the spatial size stays fixed throughout and, UNet-like, the final output is again (32, 32, 512). The differences from RSU4 are that RSU4F never changes the input's spatial size (H, W) and uses more dilated convolutions; everything else is much the same.

class RSU4F(nn.Layer):  #UNet04FRES(nn.Layer):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)  # effective kernel size 3
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)  # effective kernel size 5
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)  # effective kernel size 9

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)  # effective kernel size 17

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(paddle.concat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(paddle.concat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(paddle.concat((hx2d, hx1), 1))

        return hx1d + hxin
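
And a matching hypothetical check for RSU4F; with no pooling anywhere in the block, H and W pass through untouched:

import paddle

block = RSU4F(in_ch=512, mid_ch=256, out_ch=512)
y = block(paddle.randn([1, 512, 32, 32]))
print(y.shape)  # [1, 512, 32, 32]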

Reprinted from blog.csdn.net/qq_29304033/article/details/125802601