【图像分割】Unet详解

Unet结构

Unet结果如下图所示:大致可以分为三部分,卷积部分、下采样部分、上采样部分。
在这里插入图片描述
卷积 如图中蓝色部分表示卷积。该部分包括两个3*3的卷积层,每个卷积层后面有一个RELU。该部分主要是对图像的大小进行调整(为什么调整)
代码实现如下:

class Conv_Block(nn.Module):
    def __init__(self, in_channel,out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel,out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )

    def forward(self,x):
        return self.layer(x)

下取样

左侧是进行一步步的向下采样,这部分主要是特征提取,也可以理解为编码器,输出通道数是输入通道数的double。
1、通道数代表什么意思?
通道数可以理解为特征图的数目,每一个通道对应一个特种也对应一个特征图。
2、通道数是怎么变化的?
例如图中的第一次向下采样,输入通道数由64变为128。即在卷积的时候有128组卷积核,这64个通道分别与每一组卷积核做卷积,这样每一组卷积核就对应一个通道,128组卷积核就对应128个通道。
这64个通道分别与每一组卷积核做卷积这一句的意思是64个通道对应的64张特征图与每一组卷积核做完卷积之后生成的特征图要叠加在一起,所以每一组卷积核只生成一张特征图。
理解了上面,我们就可以轻松的理解向下采样的意义就是提取特征
3、下采样采取的特征是累加的还是在浅层样本的基础之上进行深一步的采样?
下采样的过程是在浅层样本的基础之上进行深一步的采样,并不是说深层的特征图包含浅层的特征图。
4、那都已经采样到第四层了,前三层的样本就真的用不到了吗?上采样的时候不应该也包含了前三层的样本吗?
答:首先要明确下采样和上采样的目的是什么:是要生成一个分割图像!不是还原原图。 下采样过程中第一层只包含第一层的信息,第二层只包含第二层的信息,第三层只包含第三层的信心,第四层只包含第四层的信息(就是第n层只有第n层的信息,不包含前面n-1层的信息)。这样的话上采样的第一步输入信息只有下采样的第四特征图,但是如果直接利用第四特征图进行还原的话会失真,因为第四特征图包含的特征信息很多(第四特征图有1024,但是第一特征图只有64,相比之下很多),所以在上采样的时候用到了跳跃连接,这样第四特征图参考第三特征图进行上采样,以此类推,直到采样结束,这样子输出的分割图像比较贴合真实图像。
其实前三层的特征信息通过跳跃连接用来做对照,主要用到的还是第四层的信息。
代码实现如下:

class DownSample(nn.Module):
    def __init__(self,channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel,channel, 3, 2, 1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )

    def forward(self,x):
        return self.layer(x)

上取样

上取样与下取样类似,只不过是输出的通道数变为输入通道数的一半。
上采样是为了特征融合
代码实现如下:

class UpSample(nn.Module):
    def __init__(self,channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel//2, 1, 1)
    def forward(self,x,feature_map):
        #上采样
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim = 1) 

跳跃连接

U-Nets通过跳跃连接,使用在编码器部分学习的细粒度细节在解码器部分构建图像。
损失函数
文中提到的损失函数

整体网络代码

import torch
from torch import nn
from torch.nn import functional as F

class Conv_Block(nn.Module):
    def __init__(self, in_channel,out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel,out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )

    def forward(self,x):
        return self.layer(x)


class DownSample(nn.Module):
    def __init__(self,channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel,channel, 3, 2, 1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )

    def forward(self,x):
        return self.layer(x)

class UpSample(nn.Module):
    def __init__(self,channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel//2, 1, 1)
    def forward(self,x,feature_map):
        #上采样
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim = 1) #不理解

class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.c1 = Conv_Block(3,64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64,128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128,256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256,512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512,1024)
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024,512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512,256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256,128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128,64)
        self.out = nn.Conv2d(64,3, 3, 1, 1)
        self.Th = nn.Sigmoid()

    def forward(self,x):
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        o1 = self.c6(self.u1(R5, R4))
        o2 = self.c7(self.u2(o1, R3))
        o3 = self.c8(self.u3(o2, R2))
        o4 = self.c9(self.u4(o3, R1))

        return self.Th(self.out(o4))


if __name__ == '__main__':
    x = torch.randn(2,3, 256,256)
    net = Unet()
    print(net(x).shape)

Question

1、为什么要卷积?卷积的意义在哪里,只是为了调整图像的大小吗?可以不调整吗?为什么不能直接下采样?
2、为什么要进行下采样,特征提取有什么好处?
下采样提取特征,你不看见本质你怎么学习?
3、为什么要进上采样,特征融合的好处是什么?为什么上采样的时候要Corp?
下采样相当于你看到了本质,上采样是根据本质组合成想要的目标。
4、多尺度特征融合?增大感受野?不用pool全连接?底层特征?
5、unet的优缺点