Image Segmentation: FCN Fully Convolutional Neural Network (Complete Code Explained in Detail)

Table of contents

FCN fully convolutional neural network

Implementation process

Full convolution

Deconvolution

Three innovations of FCN

Code

FCN fully convolutional neural network

        FCN is the pioneering work of deep learning in the field of semantic segmentation. It proposes replacing the fully connected layers of a CNN with convolutional layers, so that the network outputs a heat map (a per-pixel score map) instead of a single category.

Implementation process

Figure 1 FCN network structure

        The pipeline consists of a full convolution stage and a deconvolution stage.

        Full convolution: a classic CNN, e.g. VGG, ResNet, or AlexNet, is used as the backbone network. This article uses VGG16 as the backbone to extract feature maps.

        Deconvolution: the feature map is upsampled (e.g. via transposed convolution) to restore the original image size.

        The predicted result is then compared with the ground-truth label pixel by pixel, which is also called pixel-level classification. In this way the segmentation problem is transformed into a classification problem.
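        In practice this means the loss is an ordinary classification loss applied at every pixel. A minimal sketch (the class count of 12 and the tensor shapes are illustrative assumptions, matching the demo at the end of this article):

```python
import torch
from torch import nn

# logits from the network: (batch, num_classes, H, W); labels: (batch, H, W) with class indices
logits = torch.randn(4, 12, 360, 480)      # illustrative shapes
labels = torch.randint(0, 12, (4, 360, 480))

criterion = nn.CrossEntropyLoss()          # cross-entropy applied per pixel
loss = criterion(logits, labels)
print(loss.item())
```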

Full convolution

        In Figure 1, blue denotes convolution operations and green denotes pooling operations (each pooling halves the image width and height). The network structure is therefore: conv1 (2 convolution layers), pool1, conv2 (2 convolution layers), pool2, conv3 (3 convolution layers), pool3 (first branch for the prediction output), conv4 (3 convolution layers), pool4 (second branch for the prediction output), conv5 (3 convolution layers), pool5, conv6, conv7 (last branch for the prediction output). The extracted pool3, pool4, and conv7 feature maps are used for the subsequent feature fusion and deconvolution operations.
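        In VGG only the pooling layers change the spatial size (the 3x3 convolutions use padding=1), so the downsampling factors can be traced with pooling alone. A quick sketch (the 224x224 input size is just an illustrative assumption):

```python
import torch
from torch import nn

x = torch.randn(1, 3, 224, 224)              # illustrative input
pool = nn.MaxPool2d(kernel_size=2, stride=2)
for i in range(5):                           # pool1 .. pool5 each halve the spatial size
    x = pool(x)
    print(f"after pool{i + 1}: {tuple(x.shape[2:])}")
# pool3 -> (28, 28) = 1/8, pool4 -> (14, 14) = 1/16, pool5 -> (7, 7) = 1/32
```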

Deconvolution

        FCN comes in three network structures: FCN-32s, FCN-16s, and FCN-8s.

        FCN-8s acquisition process: the conv7 features are upsampled by 2x and fused with pool4; the result is upsampled by 2x again and fused with pool3; finally an 8x upsampling produces a feature map of the original image size. FCN-16s acquisition process: conv7 is upsampled by 2x and fused with pool4, and the result is upsampled by 16x. FCN-32s acquisition process: conv7 is directly upsampled by 32x to obtain a feature map of the original image size.

        Since FCN-8s fuses features from more layers, it gives the best results; FCN-32s only upsamples the last layer (conv7) by 32x for prediction, so the feature map it starts from is small and much spatial detail is lost.

         Note: FCN-8s = ((conv7 2x upsampling + pool4) 2x upsampling + pool3) -> 8x upsampling

        During the convolution stage the features pass through pooling layers. When h or w is odd, the height and width of the pooled feature map are not exactly 1/2 of the original, so the shape produced by 2x transposed-convolution upsampling can differ from the feature map it should be fused with. The feature map size is therefore adjusted with torch.nn.functional.interpolate to ensure it can be fused with the features from the earlier layer.
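        A minimal sketch of the size mismatch and its fix (the 45x45 spatial size and the channel count are illustrative assumptions):

```python
import torch
from torch import nn
import torch.nn.functional as F

feat = torch.randn(1, 256, 45, 45)                                   # odd spatial size before pooling
pooled = nn.MaxPool2d(2, stride=2)(feat)                             # -> (1, 256, 22, 22), not exactly half of 45
up = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)(pooled)   # -> (1, 256, 44, 44), off by one
aligned = F.interpolate(up, size=(feat.size(2), feat.size(3)))       # resized back to (45, 45) so it can be fused
print(pooled.shape, up.shape, aligned.shape)
```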

   

Three innovations of FCN

(1) Full convolution: the last fully connected layers of a traditional CNN are converted into convolutional layers, turning the classifier into a dense prediction (that is, segmentation).

Specific operation: change the fully connected layers of the original CNN into convolution operations (see conv6 and conv7 in Figure 1). At this point the number of feature maps changes but the spatial size is still 1/32 of the original image; the output is no longer called a featureMap but a heatMap.
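A minimal sketch of the idea, assuming a 512-channel backbone output at 7x7 (1/32 of a 224x224 input): the fully connected head, which flattens the map and only works for one fixed input size, is replaced by a 1x1 convolution (as conv6 and conv7 do in the code below), which keeps the spatial grid and yields a prediction per location:

```python
import torch
from torch import nn

feat = torch.randn(1, 512, 7, 7)                 # backbone output, 1/32 of a 224x224 input

# fully connected head: flattens the map, fixed input size, no spatial layout in the output
fc = nn.Linear(512 * 7 * 7, 4096)
out_fc = fc(feat.flatten(1))                     # (1, 4096): one vector per image

# convolutional head: works for any input size and keeps the spatial grid
conv = nn.Conv2d(512, 4096, kernel_size=1)
out_conv = conv(feat)                            # (1, 4096, 7, 7): one prediction per location
print(out_fc.shape, out_conv.shape)
```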

(2) Upsampling: since the backbone extracts features through a series of downsampling (pooling) operations, the feature map shrinks. To obtain a prediction layer with the same size as the original image, upsampling (e.g. a transposed convolution operation) is applied.
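A minimal sketch of 2x upsampling with a transposed convolution (the channel count and spatial size are illustrative assumptions):

```python
import torch
from torch import nn

x = torch.randn(1, 12, 45, 60)                              # a small score map
up2 = nn.ConvTranspose2d(12, 12, kernel_size=2, stride=2)   # doubles height and width
print(up2(x).shape)                                         # torch.Size([1, 12, 90, 120])
```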

(3) Skip structure: similar to ResNet, feature maps from different layers are fused, so information from multiple layers can be combined in the classification prediction.

Code

The code below implements the FCN-8s pipeline; other network structures can be obtained by modifying only the forward function of the FCN class.

import torch
from torch import nn
from torchvision.models import vgg16
import torch.nn.functional as F


def vgg_block(num_convs, in_channels, out_channels):
    """
    vgg block: Conv2d ReLU MaxPool2d
    """
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1))
        else:
            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1))
        blk.append(nn.ReLU(inplace=True))
    blk.append(nn.MaxPool2d(kernel_size=(2, 2), stride=2))
    return blk


class VGG16(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG16, self).__init__()
        features = []
        features.extend(vgg_block(2, 3, 64))
        features.extend(vgg_block(2, 64, 128))
        features.extend(vgg_block(3, 128, 256))
        self.index_pool3 = len(features)  # pool3
        features.extend(vgg_block(3, 256, 512))
        self.index_pool4 = len(features)  # pool4
        features.extend(vgg_block(3, 512, 512))  # pool5

        self.features = nn.Sequential(*features)  # model container; its state_dict (a dict) holds the parameters

        """ 将传统CNN中的全连接操作,变成卷积操作conv6 conv7 此时不进行pool操作,图像大小不变,此时图像不叫feature map而是heatmap"""
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)   # conv6
        self.relu = nn.ReLU(inplace=True)
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)  # conv7

        # load pretrained params from torchvision.models.vgg16(pretrained=True)
        if pretrained:
            pretrained_model = vgg16(pretrained=pretrained)
            pretrained_params = pretrained_model.state_dict()  # state_dict() holds the learned weights and biases (a dict)
            keys = list(pretrained_params.keys())
            new_dict = {}
            for index, key in enumerate(self.features.state_dict().keys()):
                new_dict[key] = pretrained_params[keys[index]]
            self.features.load_state_dict(new_dict)  # load_state_dict expects a dict; load the pretrained weights into features

    def forward(self, x):
        pool3 = self.features[:self.index_pool3](x)  # 1/8 of the original image size
        pool4 = self.features[self.index_pool3:self.index_pool4](pool3)  # 1/16 of the original image size
        # pool4 = self.features[:self.index_pool4](x)    # alternative for pool4, slower (recomputes from the start)

        pool5 = self.features[self.index_pool4:](pool4)  # 1/32 of the original image size

        conv6 = self.relu(self.conv6(pool5))  # 1/32 of the original image size
        conv7 = self.relu(self.conv7(conv6))  # 1/32 of the original image size

        return pool3, pool4, conv7


class FCN(nn.Module):
    def __init__(self, num_classes, backbone='vgg'):
        """
        Args:
            num_classes: number of classes
            backbone: backbone network (VGG)
        """
        super(FCN, self).__init__()
        if backbone == 'vgg':
            self.features = VGG16()  # initialize the backbone (loads pretrained VGG16 weights)

        # 1x1 convolutions map the channel dimension to the number of classes
        self.scores1 = nn.Conv2d(4096, num_classes, kernel_size=1)  # applied to conv7
        self.relu = nn.ReLU(inplace=True)
        self.scores2 = nn.Conv2d(512, num_classes, kernel_size=1)   # applied to pool4
        self.scores3 = nn.Conv2d(256, num_classes, kernel_size=1)   # applied to pool3

        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=8, stride=8)  # transposed convolution
        self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, stride=2)


    def forward(self, x):
        b, c, h, w = x.shape
        pool3, pool4, conv7 = self.features(x)

        conv7 = self.relu(self.scores1(conv7))

        pool4 = self.relu(self.scores2(pool4))

        pool3 = self.relu(self.scores3(pool3))

        # adjust h, w before fusion
        conv7_2x = F.interpolate(self.upsample_2x(conv7), size=(pool4.size(2), pool4.size(3)))  # 2x upsample conv7 and resize to pool4's size
        s = conv7_2x + pool4  # fuse the upsampled conv7 with pool4

        s = F.interpolate(self.upsample_2x(s), size=(pool3.size(2), pool3.size(3)))  # 2x upsample the fused features and resize to pool3's size
        s = pool3 + s     # fuse with pool3

        out_8s = F.interpolate(self.upsample_8x(s), size=(h, w))  # 8x upsample to obtain FCN-8s, the same size as the input x

        return out_8s

if __name__ == '__main__':
    model = FCN(num_classes=12)

    fake_img = torch.randn((4, 3, 360, 480))  # B C H W

    output_8s = model(fake_img)
    print(output_8s.shape)


output:

torch.Size([4, 12, 360, 480])
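As noted above, the other variants only require changes to the forward function. A minimal sketch of an FCN-16s forward pass, assuming a hypothetical 16x transposed-convolution layer named upsample_16x has been added in FCN.__init__ alongside the existing layers:

```python
# hypothetical addition in FCN.__init__:
#     self.upsample_16x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=16)

def forward(self, x):  # FCN-16s variant of FCN.forward
    b, c, h, w = x.shape
    pool3, pool4, conv7 = self.features(x)

    conv7 = self.relu(self.scores1(conv7))
    pool4 = self.relu(self.scores2(pool4))

    # 2x upsample conv7, align to pool4, fuse, then 16x upsample back to the input size
    conv7_2x = F.interpolate(self.upsample_2x(conv7), size=(pool4.size(2), pool4.size(3)))
    s = conv7_2x + pool4
    out_16s = F.interpolate(self.upsample_16x(s), size=(h, w))
    return out_16s
```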

Origin: blog.csdn.net/m0_63077499/article/details/127375650