全卷积神经网路【U-net项目实战】论文中U-Net网络实现

设计神经网络的一般步骤

  1. 设计框架
  2. 设计骨干网络

Unet网络设计的步骤

  1. 设计Unet网络工厂模式
  2. 设计编解码结构
  3. 设计卷积模块
  4. unet实例模块

Unet网络最重要的特征

编解码结构。
2. 解码结构,比FCN更加完善,采用连接方式。
3. 本质是一个框架,编码部分可以使用很多图像分类网络。

示例代码

import torch
import torch.nn as nn

class Unet(nn.Module):
    #初始化参数:Encoder,Decoder,bridge
    #bridge默认值为无,如果有参数传入,则用该参数替换None
    def __init__(self,Encoder,Decoder,bridge = None):
       super(Unet,self).__init__()
       self.encoder = Encoder(encoder_blocks)
       self.decoder = Decoder(decoder_blocks)
       self.bridge = bridge
    def forward(self,x):
        res = self.encoder(x)
        out,skip = res[0],res[1,:]
        if bridge is not None:
            out = bridge(out)
        out = self.decoder(out,skip)
        return out
#设计编码模块
class Encoder(nn.Module):
    def __init__(self,blocks):
        super(Encoder,self).__init__()
        #assert:断言函数,避免出现参数错误
        assert len(blocks) > 0
        #nn.Modulelist():模型列表,所有的参数可以纳入网络,但是没有forward函数
        self.blocks = nn.Modulelist(blocks)
    def forward(self,x):
        skip = []
        for i in range(len(self.blocks) - 1):
            x = self.blocks[i](x)
            skip.append(x)
        res = [self.block[i+1](x)]
        #列表之间可以通过+号拼接
        res +=  skip
        return res
#设计Decoder模块
class Decoder(nn.Module):
    def __init__(self,blocks):
        super(Decoder, self).__init__()
        assert len(blocks) > 0
        self.blocks = nn.Modulelist(blocks)
    def ceter_crop(self,skips,x):
        _,_,height1,width1 = skips.shape()
        _,_,height2,width2 = x.shape()
        #对图像进行剪切处理,拼接的时候保持对应size参数一致
        ht,wt = min(height1,height2),min(width1,width2)
        dh1 = (height1 - height2)//2 if height1 > height2 else 0
        dw1 = (width1 - width2)//2 if width1 > width2 else 0
        dh2 = (height2 - height1)//2 if height2 > height1 else 0
        dw2 = (width2 - width1)//2 if width2 > width1 else 0
        return skips[:,:,dh1:(dh1 + ht),dw1:(dw1 + wt)],\
               x[:,:,dh2:(dh2 + ht),dw2 : (dw2 + wt)]

    def forward(self, skips,x,reverse_skips = True):
        assert len(skips) == len(blocks) - 1
        if reverse_skips is True:
            skips = skips[: : -1]
        x = self.blocks[0](x)
        for i in range(1, len(self.blocks)):
            skip = skips[i-1]
            x = torch.cat(skip,x,1)
            x = self.blocks[i](x)
        return x
#定义了一个卷积block
def unet_convs(in_channels,out_channels,padding = 0):
    #nn.Sequential:与Modulelist相比,包含了forward函数
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False),
        nn.BatchNorm2d(outchannels),
        nn.ReLU(inplace = True),
        nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(outchannels),
        nn.ReLU(inplace=True),
    )
#实例化Unet模型
def unet(in_channels,out_channels):
    encoder_blocks = [unet_convs(in_channels, 64),\
                      nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),\
                                    unet_convs(64,128)), \
                      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
                                    unet_convs(128, 256)),
                      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
                                    unet_convs(256, 512)),
                      ]
    bridge = nn.Sequential(unet_convs(512, 1024))
    decoder_blocks = [nn.conTranpose2d(1024, 512), \
                      nn.Sequential(unet_convs(1024, 512),
                                    nn.conTranpose2d(512, 256)),\
                      nn.Sequential(unet_convs(512, 256),
                                    nn.conTranpose2d(256, 128)), \
                      nn.Sequential(unet_convs(512, 256),
                                    nn.conTranpose2d(256, 128)), \
                      nn.Sequential(unet_convs(256, 128),
                                    nn.conTranpose2d(128, 64))
                      ]
    return Unet(encoder_blocks,decoder_blocks,bridge)
发布了650 篇原创文章 · 获赞 190 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/weixin_43838785/article/details/104450064