【医学图像分割】Dense U-Net的Pytorch代码实现

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

一、DenseNet

(此处原为 DenseNet 结构示意图,图片在转载时丢失。)

二、代码

参考了以下2份代码,另外也考虑过使用H-DenseU-Net的代码。

  1. https://github.com/stefano-malacrino/DenseUNet-pytorch
  2. https://github.com/THUHoloLab/Dense-U-net
    第一次编写代码,感觉很多地方有些冗余,之后再优化,目前可以正常运行。(●’◡’●)
import torch
import torch.nn as nn
class conv_block(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages (U-Net style double conv).

    Stride 1 / padding 1 keep the spatial size; channels go
    ``ch_in -> ch_out -> ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        for stage_in in (ch_in, ch_out):
            stages += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        # Same attribute name and layer order as before, so state_dict keys
        # (conv.0.weight, conv.1.running_mean, ...) remain compatible.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    """Upsample by 2 (nn.Upsample default mode), then 3x3 Conv -> BatchNorm -> ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        return self.up(x)
class Conv_Block(nn.Module):
    """Pre-activation unit: BatchNorm -> ReLU -> 3x3 Conv; spatial size preserved.

    This is the per-layer building block of ``dens_block``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = (
            nn.BatchNorm2d(ch_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
        )
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class dens_block(nn.Module):
    """Three-layer densely connected block built from ``Conv_Block`` units.

    Each layer receives the channel-wise concatenation of the block input and
    all earlier layer outputs. Only the last layer's output (``ch_out``
    channels) is returned, not the full concatenation.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv1 = Conv_Block(ch_in, ch_out)
        self.conv2 = Conv_Block(ch_in + ch_out, ch_out)
        self.conv3 = Conv_Block(ch_in + 2 * ch_out, ch_out)

    def forward(self, input_tensor):
        first = self.conv1(input_tensor)
        second = self.conv2(torch.cat((first, input_tensor), dim=1))
        # Channel order [x1, input, x2] must stay fixed so conv3's weights line up.
        return self.conv3(torch.cat((first, input_tensor, second), dim=1))
class Conv2D(nn.Module):
    """Single 3x3 Conv (with bias) -> BatchNorm -> ReLU; spatial size preserved."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv = nn.Sequential(conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
class DenseU_Net(nn.Module):
    """U-Net whose encoder and decoder stages are ``dens_block`` modules.

    Encoder: 7x7 stem conv (32 ch), then dense blocks (64, 64, 128 ch) and a
    double-conv block (256 ch), each followed by 2x2 max pooling.
    Centre: two Conv-BN-ReLU layers (512 ch) and dropout(0.5).
    Decoder: four stages of (upsample-conv, concat skip, dense block), then a
    7x7 conv + ReLU and a 3x3 conv to ``output_ch`` channels (no final
    activation).

    Input height and width must be divisible by 16 (four poolings); otherwise
    the skip concatenations fail with a size mismatch.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        # Registration order is unchanged so parameter order / init RNG usage match.
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv0 = nn.Conv2d(img_ch, 32, kernel_size=7, padding=3, stride=1)
        self.Conv1 = dens_block(ch_in=32, ch_out=64)
        self.Conv2 = dens_block(ch_in=64, ch_out=64)
        self.Conv3 = dens_block(ch_in=64, ch_out=128)
        self.Conv4 = conv_block(ch_in=128, ch_out=256)

        # Centre
        self.Conv5_1 = Conv2D(ch_in=256, ch_out=512)
        self.Conv5_2 = Conv2D(ch_in=512, ch_out=512)
        self.Drop5 = nn.Dropout(0.5)

        # Decoder: (upsample-conv, concat function, dense block) per stage.
        self.Up6 = up_conv(512, 512)
        self.add6 = torch.cat
        self.up6 = dens_block(512 + 256, 256)

        self.Up7 = up_conv(256, 256)
        self.add7 = torch.cat
        self.up7 = dens_block(256 + 128, 128)

        self.Up8 = up_conv(128, 128)
        self.add8 = torch.cat
        self.up8 = dens_block(128 + 64, 64)

        self.Up9 = up_conv(64, 64)
        self.add9 = torch.cat
        self.up9 = dens_block(64 + 64, 64)

        # Head
        self.conv10_1 = nn.Conv2d(64, 32, 7, 1, 3)
        self.relu = nn.ReLU(inplace=True)
        self.conv10_2 = nn.Conv2d(32, output_ch, 3, 1, 1)

    def forward(self, x):
        # Encoder: remember each pre-pooling feature map as a skip connection.
        feat = self.Conv0(x)
        skips = []
        for encoder in (self.Conv1, self.Conv2, self.Conv3, self.Conv4):
            feat = encoder(feat)
            skips.append(feat)
            feat = self.Maxpool(feat)

        # Centre
        feat = self.Drop5(self.Conv5_2(self.Conv5_1(feat)))

        # Decoder: upsample, concatenate [skip, upsampled] on channels, refine.
        stages = (
            (self.Up6, self.add6, self.up6),
            (self.Up7, self.add7, self.up7),
            (self.Up8, self.add8, self.up8),
            (self.Up9, self.add9, self.up9),
        )
        for (upsample, concat, refine), skip in zip(stages, reversed(skips)):
            feat = refine(concat([skip, upsample(feat)], dim=1))

        # Head
        return self.conv10_2(self.relu(self.conv10_1(feat)))
if __name__ == "__main__":
    # Smoke test: build the model and push one random 256x256 RGB image through it.
    # Guarded so that importing this module does not build a model or run a forward pass.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DenseU_Net(img_ch=3, output_ch=1).to(device)
    print(model)
    input_1 = torch.rand(1, 3, 256, 256).to(device)
    print(input_1.shape)
    model.eval()  # disable dropout and use BatchNorm running stats for the check
    with torch.no_grad():  # a shape check needs no autograd graph
        output = model(input_1)
    print(output.shape)

三、自己捣鼓的过程(大家就跳过吧)

我主要在DenseU-Net代码方面纠结捣鼓了3天(太菜了┭┮﹏┭┮)。第一天很简单地从网上找代码,找到一个用pytorch写的DenseU-Net,写得很复杂,我没怎么看懂,它的密集连接不是封装好的,用了很多编程技巧;把数据集放上去运行之后,发现它的准确率很低。之后我就想着和ResU-Net一样调用Pytorch现有的网络。考虑参数量,我选择了DenseNet121,它的结构是3-64-256-512-1024,如果对称设计的话中间层通道数为2048,解码器为1024-512-256-64-1。参数量很大(约1亿),尝试运行之后发现过拟合了:训练集上准确率99%,验证集准确率在97%左右徘徊,没有上升趋势。然后考虑把网络调小一点,可能就是这个时候把项目弄坏了,准确率在20%左右,不管是测试集还是验证集。我以为是模型调得太小,又重复实验了几次,准确率还是在20%左右,于是放弃这个方案。之后又转向H-Dense U-net,还是看不懂,把Keras代码转成Pytorch也失败了。
之后看到一篇论文,大概是把DenseU-Net用到与SLAM有关的视频分割上,它的框架是Keras。我没有学过这个框架,就把它的代码改写成了Pytorch版本,运行后准确率仍只有20%。这时我突然想到,问题可能不在模型架构,而在项目本身,于是开始做实验,发现U-Net的准确率也只有20%。于是我更换了项目,实现了Dense-Unet。
下面是参考DenseNet121模型和【医学图像分割网络】之Res U-Net网络PyTorch复现编写的模型,只是模型太大,会过拟合,而且电脑性能不太行,运行不起来。

"""
Dense121 + U-Net
"""
import torch
from torch import nn
import torchvision.models as models
import torch.nn.functional as F
from torchsummary import summary


class expansive_block(nn.Module):
    """Decoder stage: 2x bilinear upsample, optional skip concat, two conv layers.

    Each conv layer is 3x3 Conv -> ReLU -> BatchNorm (ReLU comes before BN here).

    Args:
        in_channels: channels fed to the first conv, i.e. channels of ``d``
            plus channels of ``e`` when a skip tensor is passed.
        mid_channels: output channels of the first conv.
        out_channels: output channels of the second conv.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, d, e=None):
        """Upsample ``d`` by 2; if ``e`` is given, concatenate ``[e, d]`` on channels first."""
        upsampled = F.interpolate(d, scale_factor=2, mode='bilinear', align_corners=True)
        merged = upsampled if e is None else torch.cat([e, upsampled], dim=1)
        return self.block(merged)

def final_block(in_channels, out_channels):
    """Build the output head: 3x3 Conv -> ReLU -> BatchNorm (spatial size preserved).

    NOTE(review): because ReLU and BatchNorm follow the last conv, the network
    output is not raw logits; confirm this matches the loss being used.
    """
    conv = nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels, padding=1)
    return nn.Sequential(conv, nn.ReLU(), nn.BatchNorm2d(out_channels))


class DenseUnet121_Unet(nn.Module):
    """U-Net with a torchvision DenseNet-121 encoder.

    Encoder: the DenseNet-121 stem (``conv0`` + ``norm0``), its four dense
    blocks and three transition layers. Bottleneck: two 3x3 Conv-ReLU-BN layers
    to 2048 channels followed by a 2x2 max pool. Decoder: five
    ``expansive_block`` stages, each upsampling by 2; the first four
    concatenate the matching encoder feature map. ``final_block`` then maps
    to ``out_channel`` channels.

    Assuming torchvision's standard DenseNet-121 (stride-2 stem conv, transition
    layers that halve resolution), the output has the same H x W as the input,
    and H and W should be divisible by 32 so the skip concatenations line up.

    Args:
        in_channel: number of input channels. If it differs from the stem
            conv's input channels (3 for DenseNet-121), the stem conv is
            replaced by a freshly initialised conv of the same geometry, so
            pretrained stem weights are not used in that case.
        out_channel: number of output channels.
        pretrained: forwarded to ``torchvision.models.densenet121``.
        verbose: keyword-only. When True, ``forward`` prints the shape of every
            intermediate tensor (useful for debugging).

    NOTE(review): ``self.densenet`` is itself registered as a submodule. As a
    result, ``state_dict`` holds both ``densenet.*`` keys and the aliased
    ``denseblock*``/``transition*`` keys, and includes DenseNet layers that
    ``forward`` never uses (e.g. the classifier).
    """

    def __init__(self, in_channel, out_channel, pretrained=False, *, verbose=False):
        super(DenseUnet121_Unet, self).__init__()
        # When True, forward() prints intermediate shapes.
        self.verbose = verbose

        self.densenet = models.densenet121(pretrained=pretrained)
        stem_conv = self.densenet.features.conv0
        if in_channel != stem_conv.in_channels:
            # Previously `in_channel` was ignored and non-RGB input crashed in conv0;
            # swap in a conv with the same geometry but the requested input channels.
            self.densenet.features.conv0 = nn.Conv2d(
                in_channel,
                stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=stem_conv.bias is not None,
            )
        self.layer0 = nn.Sequential(
            self.densenet.features.conv0,
            self.densenet.features.norm0
        )

        # Encode
        self.denseblock1 = self.densenet.features.denseblock1
        self.denseblock2 = self.densenet.features.denseblock2
        self.denseblock3 = self.densenet.features.denseblock3
        self.denseblock4 = self.densenet.features.denseblock4
        self.transition1 = self.densenet.features.transition1
        self.transition2 = self.densenet.features.transition2
        self.transition3 = self.densenet.features.transition3
        # Not used in forward(); kept so existing attribute access keeps working.
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
            nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        )

        # Decode: in_channels = upsampled channels + skip channels.
        self.conv_decode4 = expansive_block(1024 + 2048, 1024, 1024)
        self.conv_decode3 = expansive_block(1024 + 1024, 512, 512)
        self.conv_decode2 = expansive_block(512 + 512, 256, 256)
        self.conv_decode1 = expansive_block(256 + 256, 128, 128)
        self.conv_decode0 = expansive_block(128, 64, 64)
        self.final_layer = final_block(64, out_channel)

    def forward(self, x):
        x = self.layer0(x)
        # Encode
        encode_block1 = self.denseblock1(x)
        encode_block2 = self.denseblock2(self.transition1(encode_block1))
        encode_block3 = self.denseblock3(self.transition2(encode_block2))
        encode_block4 = self.denseblock4(self.transition3(encode_block3))

        # Bottleneck (ends with a 2x2 max pool, so it is half the size of encode_block4)
        bottleneck = self.bottleneck(encode_block4)

        # Decode: each stage upsamples by 2 and concatenates the matching encoder map.
        decode_block4 = self.conv_decode4(bottleneck, encode_block4)
        decode_block3 = self.conv_decode3(decode_block4, encode_block3)
        decode_block2 = self.conv_decode2(decode_block3, encode_block2)
        decode_block1 = self.conv_decode1(decode_block2, encode_block1)
        # No skip here: this extra upsample compensates for the stride of the stem.
        decode_block0 = self.conv_decode0(decode_block1)

        final_layer = self.final_layer(decode_block0)

        # getattr: models pickled before `verbose` existed lack the attribute.
        if getattr(self, "verbose", False):
            for tensor in (
                encode_block1, encode_block2, encode_block3, encode_block4,
                bottleneck,
                decode_block4, decode_block3, decode_block2, decode_block1, decode_block0,
                final_layer,
            ):
                print(tensor.shape)
        return final_layer


# Set to a truthy value to run a quick 224x224 smoke test.
flag = 0

if flag:
    image = torch.rand(1, 3, 224, 224)
    # Bind the instance to `net`, not to the class name, so `DenseUnet121_Unet`
    # remains the class and can still be constructed further down the file.
    net = DenseUnet121_Unet(in_channel=3, out_channel=1)
    mask = net(image)
    print(mask.shape)

# Test the network
if __name__ == "__main__":
    # Smoke test: push one random 256x256 RGB image through the model.
    # Guarded so that importing this module does not build the model or run a forward pass.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DenseUnet121_Unet(in_channel=3, out_channel=1, pretrained=False).to(device)
    print(model)
    input_1 = torch.rand(1, 3, 256, 256).to(device)
    print(input_1.shape)
    model.eval()  # use BatchNorm running stats for the check
    with torch.no_grad():  # a shape check needs no autograd graph
        output = model(input_1)
    print(output.shape)



总结

猜你喜欢

转载自blog.csdn.net/goodenough5/article/details/129675987