Pneumonia CT Image Analysis with Grad-CAM: Understanding the Model in Depth

Contents

The Importance of Pneumonia CT Image Recognition

Grad-CAM: A Visual Explanation Method

Pneumonia CT Image Recognition with Grad-CAM

Project Background

Model Overview

Project Implementation

ResNet

BasicBlock

BottleneckBlock

CosineWarmup

train

eval

Grad-CAM

CAM

Grad-CAM

Summary


In medical image analysis, deep learning has become an essential tool: it can assist physicians with diagnosis and improve both its accuracy and its efficiency. This post shares our experience applying deep learning to pneumonia CT image recognition, with a particular focus on a visualization technique called Grad-CAM.

The Importance of Pneumonia CT Image Recognition

Pneumonia is a serious respiratory disease that causes inflammation and infection of the lungs. For patients with suspected pneumonia, physicians usually order a CT scan to examine the lungs. However, because CT images are complex, accurately identifying pneumonia requires considerable clinical experience. Using deep learning to assist pneumonia recognition in CT images is therefore an important and meaningful problem.

Grad-CAM: A Visual Explanation Method

Grad-CAM is a visual explanation method that generates heatmaps showing which parts of the input image matter most to the model's prediction. This is extremely helpful for understanding the model's decision process, especially in medical image analysis.

Pneumonia CT Image Recognition with Grad-CAM

In our application, we first trained a deep learning model to classify pneumonia CT images, then used Grad-CAM to generate heatmaps. These heatmaps clearly show which regions of an image the model attends to when making a prediction. By inspecting them, we can both verify that the model's predictions are reasonable and discover potential lesion regions that physicians might overlook.

Project Background

COVID-19 (Corona Virus Disease 2019), named "2019 coronavirus disease" by the World Health Organization, refers to pneumonia caused by infection with the 2019 novel coronavirus. Its outbreak was a worldwide disaster.

Chest X-ray images of COVID-19 patients show characteristic signs of viral infection, so accurately identifying COVID-positive images among lung radiographs has significant practical value. This project trains a deep learning model on a COVID-19 CT dataset to help physicians judge quickly and accurately whether a patient is infected.

The project uses a COVID-19 CT dataset created by a team of researchers from Qatar University and the University of Dhaka, Bangladesh, together with collaborators from Pakistan and Malaysia, working with medical doctors. It contains 3,616 COVID-19 positive cases, 10,192 normal cases, 6,012 non-COVID lung infection cases, and 1,345 viral pneumonia cases. The model is the classic ResNet50 architecture, trained with a warmup plus cosine annealing learning-rate schedule; accuracy on the test set exceeds 95%.

Medical disclaimer: the 95% figure is a result on an experimental dataset only. Any algorithm intended for clinical use must be validated in its actual deployment environment; this model's results must not be used as a basis for clinical diagnosis or treatment.

Model Overview

For shallow networks, it is well known that performance improves as more layers are stacked: more non-linear layers mean stronger feature extraction, i.e., a greater capacity to fit the data, which is why models grew ever deeper from AlexNet to VGG. Beyond a certain depth, however, performance drops instead of improving, because very deep plain networks become hard to optimize (for example, gradients vanish) and convergence stalls. This is the degradation problem.

Figure 1: The degradation problem

GoogLeNet leaned on two auxiliary losses to push its network to 22 layers and won the 2014 ILSVRC competition, but auxiliary losses treat the symptom rather than the cause; otherwise GoogLeNet would not have stopped after adding just a handful of layers. It felt like winning a marathon while carrying an oxygen tank. In 2015, ResNet arrived and used residual structures to break the depth barrier of deep neural networks; DNNs have counted their layers in the hundreds and thousands ever since.

As the winner of ILSVRC 2015, ResNet performed excellently in classification, detection, and localization. To address the degradation problem, ResNet adopts skip connections. Figure 2 shows the basic residual block introduced in the paper:

Figure 2: The basic residual block

A plain network layer outputs y = F(x), while a residual block outputs y = F(x) + x. The residual block adds an extra identity path (the shortcut); an identity mapping is one whose output equals its input. The benefit of a residual block is this: if the added layers do not improve the network, training can drive F(x) towards 0, so the block's output y approaches its input x, which is equivalent to not adding the layers at all. Figure 3 compares the training results of 18- and 34-layer plain networks against their residual counterparts:

Figure 3: Plain networks vs. residual networks
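To make the identity path concrete, here is a minimal, framework-free sketch of the residual computation (the residual branch F below is a hypothetical stand-in for the block's stacked convolutions):

import numpy as np

def residual_block(x, F):
    # Identity shortcut added to the residual branch: if training drives
    # F(x) towards 0, the block reduces to y = x (an identity mapping).
    return F(x) + x

F = lambda x: 1e-3 * x           # near-zero residual branch (hypothetical)
x = np.ones(4)
print(residual_block(x, F))      # ~[1. 1. 1. 1.]: effectively the identity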

Residual blocks come in two kinds. One is the bottleneck structure (Bottleneck) shown on the right of Figure 4, used mainly to reduce computational cost: the input first passes through a 1x1 convolution that shrinks the channel count, then a 3x3 convolution that extracts features, and finally a 1x1 convolution that restores the channel count. Because the channels narrow and then widen again, like a bottle that is thin in the middle and wide at both ends, it is called a Bottleneck. The other kind, on the left of Figure 4, is the Basic Block, built from two 3x3 convolutions. Bottleneck Blocks are used in ResNet50, ResNet101, and ResNet152, while Basic Blocks are used in ResNet18 and ResNet34.

Figure 4: Basic Block and Bottleneck Block

Shortcut paths also come in two kinds, shown in Figure 5. When the residual path's output matches the input in both channel count and feature-map size, the shortcut passes the input x through unchanged. When either the channel count or the feature-map size differs, the shortcut applies a 1x1 convolution to x so that its output matches the residual path's output in both channel count and feature-map size.

Figure 5: Shortcut paths

Project Implementation

In [ ]

import os
import cv2
import glob
import paddle
import numpy as np
import prettytable
import matplotlib.pyplot as plt
import paddle.nn.functional as F
from paddle.io import Dataset
from paddle.optimizer.lr import LinearWarmup, CosineAnnealingDecay
from paddle.vision.transforms import Compose, Resize, ToTensor, Normalize
from paddle.nn import Sequential, Conv2D, BatchNorm2D, ReLU, MaxPool2D, AdaptiveAvgPool2D, Flatten, Linear

In [ ]

# Unzip the dataset
!unzip /home/aistudio/data/data179597/dataset.zip -d work/

In [ ]

# Split the dataset into train/val/test lists
base_dir = "/home/aistudio/work/dataset/"
img_dirs = ["COVID", "LungOpacity", "Normal", "ViralPneumonia"]
file_names = ["train.txt", "val.txt", "test.txt"]
splits = [0, 0.7, 0.9, 1] # 7 : 2 : 1 split

for split_idx, file_name in enumerate(file_names):
    with open(os.path.join("/home/aistudio/work/dataset", file_name), "w") as f:
        for label, img_dir in enumerate(img_dirs):
            imgs = os.listdir(os.path.join(base_dir, img_dir))
            for idx in range(int(splits[split_idx] * len(imgs)), int(splits[split_idx + 1] * len(imgs))):
                print("{} {}".format(img_dir + "/" + imgs[idx], label), file=f)

In [3]

# Compute the per-channel mean and standard deviation
def get_mean_std(img_paths):
    print('Total images:', len(img_paths))
    # MAX, MIN = np.zeros(3), np.ones(3) * 255
    mean, std = np.zeros(3), np.zeros(3)
    for img_path in img_paths:
        img = cv2.imread(img_path)
        for c in range(3):
            mean[c] += img[:, :, c].mean()
            std[c] += img[:, :, c].std()
            # MAX[c] = max(MAX[c], img[:, :, c].max())
            # MIN[c] = min(MIN[c], img[:, :, c].min())
    mean /= len(img_paths)
    std /= len(img_paths)
    # mean /= MAX - MIN
    # std /= MAX - MIN
    return mean, std
img_paths = []
img_paths.extend(glob.glob(os.path.join("work/dataset/COVID", "*.png")))
img_paths.extend(glob.glob(os.path.join("work/dataset/LungOpacity", "*.png")))
img_paths.extend(glob.glob(os.path.join("work/dataset/Normal", "*.png")))
img_paths.extend(glob.glob(os.path.join("work/dataset/ViralPneumonia", "*.png")))
mean, std = get_mean_std(img_paths)
print('mean:', mean)
print('std:', std)
Total images: 21165
mean: [129.90909919 129.90909919 129.90909919]
std: [59.0183735 59.0183735 59.0183735]

In [ ]

# Custom dataset
class CovidDataset(Dataset):
    def __init__(self, base_dir, label_path, transform=None):
        super(CovidDataset, self).__init__()
        self.datas = []
        with open(label_path) as f:
            for line in f.readlines():
                img_path, label = line.strip().split(" ")
                img_path = os.path.join(base_dir, img_path)
                self.datas.append([img_path, label])
        self.transform = transform # transform pipeline applied to each image

    def __getitem__(self, idx):
        img_path, label = self.datas[idx]
        img = cv2.imread(img_path)
        img = img.astype("float32") # Paddle expects float32 inputs by default
        if self.transform is not None:
            img = self.transform(img)
        label = np.array([int(label)]) # cross_entropy expects integer labels
        return img, label

    def __len__(self):
        return len(self.datas)

In [5]

# Data preprocessing
transform = Compose([Resize(size=224), 
                    ToTensor(), # numpy.ndarray -> paddle.Tensor   HWC -> CHW
                    Normalize(mean=[129.90909919, 129.90909919, 129.90909919], std=[59.0183735, 59.0183735, 59.0183735], data_format='CHW')])
train_dataset = CovidDataset("work/dataset", "work/dataset/train.txt", transform)
val_dataset = CovidDataset("work/dataset", "work/dataset/val.txt", transform)
test_dataset = CovidDataset("work/dataset", "work/dataset/test.txt", transform)
print("训练集图片数量: {}\n验证集图片数量: {}\n测试集图片数量: {}".format(len(train_dataset), len(val_dataset), len(test_dataset)))
训练集图片数量: 14814
验证集图片数量: 4232
测试集图片数量: 2119

ResNet

BasicBlock

In [ ]

# Define BasicBlock
class BasicBlock(paddle.nn.Layer):
    def __init__(self, in_channels, out_channels, stride):
        super(BasicBlock, self).__init__()
        self.conv1 = Sequential(
            Conv2D(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias_attr=False), # bias_attr=False: no bias, a BatchNorm layer follows
            BatchNorm2D(out_channels), 
            ReLU()
        )
        self.conv2 = Sequential(
            Conv2D(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias_attr=False), 
            BatchNorm2D(out_channels)
        )
        # When the channel count or feature-map size changes, the shortcut uses a 1x1 conv to match the residual path
        if stride != 1 or in_channels != out_channels:
            self.shortcut = Sequential(
                Conv2D(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias_attr=False), 
                BatchNorm2D(out_channels)
            )
        else:
            self.shortcut = Sequential()
        self.relu = ReLU()
    
    def forward(self, inputs):
        out_conv1 = self.conv1(inputs)
        out_conv2 = self.conv2(out_conv1)
        outputs = self.relu(out_conv2 + self.shortcut(inputs))
        return outputs

BottleneckBlock

In [ ]

# Define BottleneckBlock
class BottleneckBlock(paddle.nn.Layer):
    def __init__(self, in_channels, out_channels, stride):
        super(BottleneckBlock, self).__init__()
        self.conv1 = Sequential(
            Conv2D(in_channels, out_channels // 4, kernel_size=1, stride=1, padding=0, bias_attr=False), 
            BatchNorm2D(out_channels // 4), 
            ReLU()
        )
        self.conv2 = Sequential(
            Conv2D(out_channels // 4, out_channels // 4, kernel_size=3, stride=stride, padding=1, bias_attr=False), 
            BatchNorm2D(out_channels // 4), 
            ReLU()
        )
        self.conv3 = Sequential(
            Conv2D(out_channels // 4, out_channels, kernel_size=1, stride=1, padding=0, bias_attr=False), 
            BatchNorm2D(out_channels)
        )
        # When the channel count or feature-map size changes, the shortcut uses a 1x1 conv to match the residual path
        if stride != 1 or in_channels != out_channels:
            self.shortcut = Sequential(
                Conv2D(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias_attr=False), 
                BatchNorm2D(out_channels)
            )
        else:
            self.shortcut = Sequential()
        self.relu = ReLU()

    def forward(self, inputs):
        out_conv1 = self.conv1(inputs)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2) 
        outputs = self.relu(out_conv3 + self.shortcut(inputs))
        return outputs

In [ ]

# Define ResNet
class ResNet(paddle.nn.Layer):
    def __init__(self, layers, num_classes):
        super(ResNet, self).__init__()
        config = {
            18: {'block_type': BasicBlock, 'num_blocks': [2, 2, 2, 2], 'out_channels': [64, 128, 256, 512]}, 
            34: {'block_type': BasicBlock, 'num_blocks': [3, 4, 6, 3], 'out_channels': [64, 128, 256, 512]}, 
            50: {'block_type': BottleneckBlock, 'num_blocks': [3, 4, 6, 3], 'out_channels': [256, 512, 1024, 2048]}, 
            101: {'block_type': BottleneckBlock, 'num_blocks': [3, 4, 23, 3], 'out_channels': [256, 512, 1024, 2048]}, 
            152: {'block_type': BottleneckBlock, 'num_blocks': [3, 8, 36, 3], 'out_channels': [256, 512, 1024, 2048]}
        }
        self.conv = Sequential(
            Conv2D(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias_attr=False), 
            BatchNorm2D(64), 
            ReLU(), 
        )
        self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
        in_channels = 64
        block_list = []
        for i, block_num in enumerate(config[layers]['num_blocks']):
            for order in range(block_num):
                block_list.append(config[layers]['block_type'](in_channels, config[layers]['out_channels'][i], 2 if order == 0 and i != 0 else 1))
                in_channels = config[layers]['out_channels'][i]
        self.block = Sequential(*block_list)
        self.avg_pool = AdaptiveAvgPool2D(1) # adaptive global average pooling -> [N, C, 1, 1]
        self.flatten = Flatten() # flatten to [N, C]
        self.fc = Linear(config[layers]['out_channels'][-1], num_classes)
    
    def forward(self, inputs):
        out_conv = self.conv(inputs)
        out_max_pool = self.max_pool(out_conv)
        out_block = self.block(out_max_pool)
        out_avg_pool = self.avg_pool(out_block)
        out_flatten = self.flatten(out_avg_pool)
        outputs = self.fc(out_flatten)
        return outputs

Figure 6: The ResNet architecture

In [9]

# Inspect the network structure
resnet50 = ResNet(50, 4)
paddle.summary(resnet50, (1, 3, 224, 224))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
     Conv2D-54       [[1, 3, 224, 224]]   [1, 64, 112, 112]        9,408     
  BatchNorm2D-54    [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
      ReLU-50       [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
    MaxPool2D-2     [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
     Conv2D-55       [[1, 64, 56, 56]]     [1, 64, 56, 56]         4,096     
  BatchNorm2D-55     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-51        [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
     Conv2D-56       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-56     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-52        [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
     Conv2D-57       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-57     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
     Conv2D-58       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-58     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
      ReLU-53        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
BottleneckBlock-17   [[1, 64, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-59       [[1, 256, 56, 56]]    [1, 64, 56, 56]        16,384     
  BatchNorm2D-59     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-54        [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
     Conv2D-60       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-60     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-55        [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
     Conv2D-61       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-61     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
      ReLU-56        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
BottleneckBlock-18   [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-62       [[1, 256, 56, 56]]    [1, 64, 56, 56]        16,384     
  BatchNorm2D-62     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-57        [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
     Conv2D-63       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-63     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-58        [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
     Conv2D-64       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-64     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
      ReLU-59        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
BottleneckBlock-19   [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-65       [[1, 256, 56, 56]]    [1, 128, 56, 56]       32,768     
  BatchNorm2D-65     [[1, 128, 56, 56]]    [1, 128, 56, 56]         512      
      ReLU-60        [[1, 128, 56, 56]]    [1, 128, 56, 56]          0       
     Conv2D-66       [[1, 128, 56, 56]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-66     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-61        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-67       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-67     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
     Conv2D-68       [[1, 256, 56, 56]]    [1, 512, 28, 28]       131,072    
  BatchNorm2D-68     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
      ReLU-62        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-20   [[1, 256, 56, 56]]    [1, 512, 28, 28]          0       
     Conv2D-69       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-69     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-63        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-70       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-70     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-64        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-71       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-71     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
      ReLU-65        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-21   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-72       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-72     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-66        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-73       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-73     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-67        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-74       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-74     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
      ReLU-68        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-22   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-75       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-75     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-69        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-76       [[1, 128, 28, 28]]    [1, 128, 28, 28]       147,456    
  BatchNorm2D-76     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-70        [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
     Conv2D-77       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-77     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
      ReLU-71        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
BottleneckBlock-23   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
     Conv2D-78       [[1, 512, 28, 28]]    [1, 256, 28, 28]       131,072    
  BatchNorm2D-78     [[1, 256, 28, 28]]    [1, 256, 28, 28]        1,024     
      ReLU-72        [[1, 256, 28, 28]]    [1, 256, 28, 28]          0       
     Conv2D-79       [[1, 256, 28, 28]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-79     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-73        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-80       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-80    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
     Conv2D-81       [[1, 512, 28, 28]]   [1, 1024, 14, 14]       524,288    
  BatchNorm2D-81    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
      ReLU-74       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-24   [[1, 512, 28, 28]]   [1, 1024, 14, 14]          0       
     Conv2D-82      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-82     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-75        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-83       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-83     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-76        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-84       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-84    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
      ReLU-77       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-25  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-85      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-85     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-78        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-86       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-86     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-79        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-87       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-87    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
      ReLU-80       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-26  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-88      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-88     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-81        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-89       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-89     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-82        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-90       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-90    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
      ReLU-83       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-27  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-91      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-91     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-84        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-92       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-92     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-85        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-93       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-93    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
      ReLU-86       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-28  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-94      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-94     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-87        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-95       [[1, 256, 14, 14]]    [1, 256, 14, 14]       589,824    
  BatchNorm2D-95     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-88        [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
     Conv2D-96       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-96    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
      ReLU-89       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
BottleneckBlock-29  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
     Conv2D-97      [[1, 1024, 14, 14]]    [1, 512, 14, 14]       524,288    
  BatchNorm2D-97     [[1, 512, 14, 14]]    [1, 512, 14, 14]        2,048     
      ReLU-90        [[1, 512, 14, 14]]    [1, 512, 14, 14]          0       
     Conv2D-98       [[1, 512, 14, 14]]     [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-98      [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-91         [[1, 512, 7, 7]]      [1, 512, 7, 7]           0       
     Conv2D-99        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-99     [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
    Conv2D-100      [[1, 1024, 14, 14]]    [1, 2048, 7, 7]       2,097,152   
  BatchNorm2D-100    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
      ReLU-92        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
BottleneckBlock-30  [[1, 1024, 14, 14]]    [1, 2048, 7, 7]           0       
    Conv2D-101       [[1, 2048, 7, 7]]      [1, 512, 7, 7]       1,048,576   
  BatchNorm2D-101     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-93         [[1, 512, 7, 7]]      [1, 512, 7, 7]           0       
    Conv2D-102        [[1, 512, 7, 7]]      [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-102     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-94         [[1, 512, 7, 7]]      [1, 512, 7, 7]           0       
    Conv2D-103        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-103    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
      ReLU-95        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
BottleneckBlock-31   [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    Conv2D-104       [[1, 2048, 7, 7]]      [1, 512, 7, 7]       1,048,576   
  BatchNorm2D-104     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-96         [[1, 512, 7, 7]]      [1, 512, 7, 7]           0       
    Conv2D-105        [[1, 512, 7, 7]]      [1, 512, 7, 7]       2,359,296   
  BatchNorm2D-105     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
      ReLU-97         [[1, 512, 7, 7]]      [1, 512, 7, 7]           0       
    Conv2D-106        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-106    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
      ReLU-98        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
BottleneckBlock-32   [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
AdaptiveAvgPool2D-2  [[1, 2048, 7, 7]]     [1, 2048, 1, 1]           0       
     Flatten-2       [[1, 2048, 1, 1]]        [1, 2048]              0       
     Linear-2           [[1, 2048]]             [1, 4]             8,196     
===============================================================================
Total params: 23,569,348
Trainable params: 23,463,108
Non-trainable params: 106,240
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.57
Params size (MB): 89.91
Estimated Total Size (MB): 377.05
-------------------------------------------------------------------------------

{'total_params': 23569348, 'trainable_params': 23463108}

CosineWarmup

The Momentum optimizer

With a well-chosen initial learning rate and enough training epochs, this optimizer stands out from the crowd and tends to reach higher validation accuracy. It has two drawbacks, however: it converges more slowly than adaptive optimizers such as Adam and AdamW, and choosing a good initial learning rate takes a lot of experience.
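For reference, the textbook momentum update rule (paddle.optimizer.Momentum uses a momentum coefficient of 0.9 by default) is

$$v_{t+1} = \mu\, v_t + \nabla_\theta L(\theta_t), \qquad \theta_{t+1} = \theta_t - \eta\, v_{t+1}$$

where $\eta$ is the learning rate: the velocity $v$ accumulates past gradients, which smooths the updates but also makes the choice of $\eta$ matter more than it does for adaptive optimizers.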

Warmup

Warmup is a learning-rate warmup method mentioned in the ResNet paper: training starts with a small learning rate for a few epochs or steps before switching to the configured learning rate. At the start of training, the weights are randomly initialized, so a large learning rate can make the model unstable (oscillate). Warming up keeps the learning rate small during the first epochs or steps so the model can gradually stabilize; once it is relatively stable, training proceeds at the preset learning rate, which speeds up convergence and improves the final result.
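Concretely, during the warmup phase LinearWarmup raises the learning rate linearly from start_lr to end_lr over the first $T_{\text{warmup}}$ steps:

$$\eta_t = \eta_{\text{start}} + \left(\eta_{\text{end}} - \eta_{\text{start}}\right) \frac{t}{T_{\text{warmup}}}, \qquad 0 \le t < T_{\text{warmup}}$$

after which control passes to the wrapped schedule (here, cosine annealing).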

Cosine annealing

When optimizing an objective with gradient descent, the learning rate should shrink as the loss approaches its global minimum, so that the model can get as close to that minimum as possible. Cosine annealing lowers the learning rate along a cosine curve: as x increases, the cosine first falls slowly, then quickly, then slowly again. This decay pattern pairs well with the learning rate and produces good results at very little computational cost.
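The schedule used here (paddle.optimizer.lr.CosineAnnealingDecay) sets the learning rate at step $t$ to

$$\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\frac{t \pi}{T_{\max}}\right)$$

with $\eta_{\min} = 0$ by default, so the rate decays smoothly from the initial value $\eta_{\max}$ to 0 over $T_{\max}$ steps.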

The warmup plus cosine annealing schedule (CosineWarmup) is very practical; this project pairs it with the Momentum optimizer to replace plain SGD.

In [10]

# Cosine annealing learning-rate schedule with linear warmup
class Cosine(CosineAnnealingDecay):
    def __init__(self, learning_rate, step_each_epoch, epoch_num, **kwargs):
        super(Cosine, self).__init__(learning_rate=learning_rate, T_max=step_each_epoch * epoch_num)

class CosineWarmup(LinearWarmup):
    def __init__(self, learning_rate, step_each_epoch, epoch_num, warmup_epoch_num=5, **kwargs):
        assert epoch_num > warmup_epoch_num, "epoch_num({}) should be larger than warmup_epoch_num({}) in CosineWarmup.".format(epoch_num, warmup_epoch_num)
        warmup_steps = warmup_epoch_num * step_each_epoch
        start_lr = 0.0
        end_lr = learning_rate
        learning_rate = Cosine(learning_rate, step_each_epoch, epoch_num - warmup_epoch_num)
        super(CosineWarmup, self).__init__(learning_rate=learning_rate, warmup_steps=warmup_steps, start_lr=start_lr, end_lr=end_lr)
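As a quick sanity check, we can step the scheduler outside of training and plot the curve it produces (the step counts below are illustrative; the actual run derives step_each_epoch from the dataset size):

# Plot the warmup + cosine schedule (illustrative step counts)
sched = CosineWarmup(learning_rate=0.01, step_each_epoch=100, epoch_num=50)
lrs = []
for _ in range(100 * 50):
    lrs.append(sched.last_lr)  # current learning rate
    sched.step()
plt.plot(lrs)
plt.xlabel('step')
plt.ylabel('learning rate')
plt.title('CosineWarmup schedule')
plt.show()

With the default warmup_epoch_num of 5, the curve rises linearly for the first 500 steps and then decays along a cosine to zero.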

train

In [ ]

# Train the model
def train(model):
    epoch_num = 50
    batch_size = 50
    learning_rate = 0.01
    train_loss_list = []
    train_acc_list = []
    eval_loss_list = []
    eval_acc_list = []
    iter = 0
    iters = []
    epochs = []
    max_eval_acc = 0

    model.train()
    train_loader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = paddle.io.DataLoader(val_dataset, batch_size=batch_size)
    scheduler = CosineWarmup(learning_rate=learning_rate, step_each_epoch=int(len(train_dataset) / batch_size), epoch_num=epoch_num, verbose=True)
    opt = paddle.optimizer.Momentum(learning_rate=scheduler, parameters=model.parameters()) # Momentum + CosineWarmup

    for epoch_id in range(epoch_num):
        for batch_id, (images, labels) in enumerate(train_loader()):
            predicts = model(images)
            loss = F.cross_entropy(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            if batch_id % 10 == 0:
                train_loss_list.append(loss.item())
                train_acc_list.append(acc.item())
                iters.append(iter)
                iter += 10
                print("epoch: {}, batch: {}, learning_rate: {}, \ntrain loss is: {}, train acc is: {}".format(epoch_id, batch_id, opt.get_lr(), loss.item(), acc.item()))
            loss.backward() # backpropagation
            opt.step() # update parameters
            opt.clear_grad() # clear gradients
            scheduler.step() # advance the learning-rate schedule
        
        # Evaluate once per epoch
        model.eval()
        loss_list = []
        acc_list = []
        results = np.zeros([4, 4], dtype='int64') # confusion matrix: rows are true labels, columns are predictions
        for batch_id, (images, labels) in enumerate(val_loader()):
            predicts = model(images)
            for i in range(len(images)):
                results[labels[i].item()][paddle.argmax(predicts[i]).item()] += 1
            loss = F.cross_entropy(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            loss_list.append(loss.item())
            acc_list.append(acc.item())
        eval_loss, eval_acc = np.mean(loss_list), np.mean(acc_list)
        eval_loss_list.append(eval_loss)
        eval_acc_list.append(eval_acc)
        epochs.append(epoch_id)
        model.train()
        print("eval loss: {}, eval acc: {}".format(eval_loss, eval_acc))
        # Save the model with the best validation accuracy
        if eval_acc > max_eval_acc:
            paddle.save(model.state_dict(), 'COVID.pdparams')
            max_eval_acc = eval_acc

        results_table = prettytable.PrettyTable()
        results_table.field_names = ['Type', 'Precision', 'Recall', 'F1_Score']
        class_names = ['COVID', 'LungOpacity', 'Normal', 'ViralPneumonia']
        for i in range(4):
            precision = results[i][i] / results.sum(axis=0)[i]
            recall = results[i][i] / results.sum(axis=1)[i]
            results_table.add_row([class_names[i], 
                                    np.round(precision, 3), 
                                    np.round(recall, 3), 
                                    np.round(precision * recall * 2 / (precision + recall), 3)])
        print(results_table)

    return train_loss_list, train_acc_list, eval_loss_list, eval_acc_list, iters, epochs

resnet50 = ResNet(50, 4)
train_loss_list, train_acc_list, eval_loss_list, eval_acc_list, iters, epochs = train(resnet50)

In [14]

# Visualize the training curves
def plot(x, values, xlabel, ylabel, title):
    plt.figure()
    plt.title(title, fontsize='x-large')
    plt.xlabel(xlabel, fontsize='large')
    plt.ylabel(ylabel, fontsize='large')
    plt.plot(x, values, color='red')
    plt.grid()
    plt.show()

plot(iters, train_loss_list, 'iter', 'loss', 'train loss')
plot(iters, train_acc_list, 'iter', 'acc', 'train acc')
plot(epochs, eval_loss_list, 'epoch', 'loss', 'eval loss')
plot(epochs, eval_acc_list, 'epoch', 'acc', 'eval acc')

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

eval

In [9]

# Evaluate on the test set
def eval(model):
    batch_size = 50
    loss_list = []
    acc_list = []
    results = np.zeros([4, 4], dtype='int64') # confusion matrix: rows are true labels, columns are predictions
    params_file_path = 'COVID.pdparams'
    # Load the trained model parameters
    param_dict = paddle.load(params_file_path)
    model.load_dict(param_dict)
    model.eval()
    test_loader = paddle.io.DataLoader(test_dataset, batch_size=batch_size)
    
    for batch_id, (images, labels) in enumerate(test_loader()):
        predicts = model(images)
        for i in range(len(images)):
            results[labels[i].item()][paddle.argmax(predicts[i]).item()] += 1
        loss = F.cross_entropy(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)
        loss_list.append(loss.item())
        acc_list.append(acc.item())
    eval_loss, eval_acc = np.mean(loss_list), np.mean(acc_list)
    print("eval_loss: {}, eval_acc: {}".format(eval_loss, eval_acc))

    results_table = prettytable.PrettyTable()
    results_table.field_names = ['Type', 'Precision', 'Recall', 'F1_Score']
    class_names = ['COVID', 'LungOpacity', 'Normal', 'ViralPneumonia']
    for i in range(4):
        precision = results[i][i] / results.sum(axis=0)[i]
        recall = results[i][i] / results.sum(axis=1)[i]
        results_table.add_row([class_names[i], 
                                np.round(precision, 3), 
                                np.round(recall, 3), 
                                np.round(precision * recall * 2 / (precision + recall), 3)])
    print(results_table)

resnet50 = ResNet(50, 4)
eval(resnet50)
eval_loss: 0.1824716518641226, eval_acc: 0.965116277683613
+----------------+-----------+--------+----------+
|      Type      | Precision | Recall | F1_Score |
+----------------+-----------+--------+----------+
|     COVID      |   0.981   |  1.0   |   0.99   |
|  LungOpacity   |   0.935   | 0.958  |  0.947   |
|     Normal     |   0.978   | 0.951  |  0.964   |
| ViralPneumonia |   0.957   |  1.0   |  0.978   |
+----------------+-----------+--------+----------+

In [10]

# Predict a single image
def predict(img_path):
    model = ResNet(50, 4)
    # Load the trained model parameters
    model.load_dict(paddle.load('COVID.pdparams'))
    model.eval()
    img = cv2.imread(img_path)
    plt.imshow(img[:, :, ::-1]) # BGR -> RGB
    plt.show()
    img = paddle.reshape(transform(img.astype('float32')), [-1, 3, 224, 224])
    # The model outputs one score (logit) per class
    results = model(img)
    # The class with the highest score is the prediction
    classes = ["COVID", "LungOpacity", "Normal", "ViralPneumonia"]
    label = paddle.argmax(results).item()
    predict_result = classes[label]
    print(predict_result)

predict("work/dataset/COVID/COVID-2838.png")
predict("work/dataset/LungOpacity/LungOpacity-1824.png")
predict("work/dataset/Normal/Normal-07525.png")
predict("work/dataset/ViralPneumonia/ViralPneumonia-0045.png")

<Figure size 640x480 with 1 Axes>
COVID

<Figure size 640x480 with 1 Axes>
LungOpacity

<Figure size 640x480 with 1 Axes>
Normal

<Figure size 640x480 with 1 Axes>
ViralPneumonia

Grad-CAM

Grad-CAM (Gradient-weighted Class Activation Mapping) evolved from CAM (Class Activation Mapping). A CAM can be read as a map of each region's contribution to the predicted output: the higher the score at a location, the stronger the network's response to the corresponding region of the original image and the larger its contribution, i.e., the map shows how important each position is for the class. Grad-CAM improves on and generalizes CAM, making it applicable to a far wider range of model architectures while highlighting the key regions even more sharply.

CAM

A typical DNN is structured as in Figure 7: a stack of convolutional layers that repeatedly shrink the feature maps and increase the channel count, extracting features at various granularities, followed by a GAP (global average pooling) layer that averages each channel's feature map, and finally a softmax-activated fully connected layer that outputs the probability of each class. The model's score for each class is the dot product of that class's weights in the final fully connected layer with the channel means produced by the GAP layer. The larger this value, the higher the probability the model assigns to the class, so it is the key quantity deciding the final output.

Figure 7: CAM
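Written out, the pre-softmax score of class $c$ in such a network is

$$S_c = \sum_k w_k^c \cdot \frac{1}{Z} \sum_{x, y} f_k(x, y)$$

where $f_k(x, y)$ is the $k$-th feature map of the last convolutional layer, $Z$ is its number of spatial positions (so the inner sum is exactly GAP), and $w_k^c$ is the fully connected weight from channel $k$ to class $c$. CAM, described next, simply swaps the order of the two summations.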

CAM is designed around the meaning of this quantity. The product of the fully connected weights with the GAP'd feature-map means decides the model's output class, but to end up with a single probability-like value, the GAP layer collapses each 2D feature map of the last convolutional layer to a scalar, throwing away spatial information. What if we skipped GAP and multiplied the 2D feature maps of the last convolutional layer directly by the final fully connected weights? We would keep the 2D spatial structure of the feature maps and still capture how important they are to the current class. That is exactly CAM, computed as follows:

$$M_c(x, y) = \sum_k w_k^c \, f_k(x, y)$$

Here $M_c(x, y)$ is the class activation map computed for class $c$, $f_k(x, y)$ is the $k$-th feature map extracted by the last convolutional layer, and $w_k^c$ is the weight in the final fully connected layer that connects channel $k$ to the score of class $c$.
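As a minimal NumPy sketch of this computation (feats and fc_weights are assumed to have been extracted from a trained model; the normalization at the end is only for display):

def compute_cam(feats, fc_weights, class_idx):
    # feats: [K, H, W] feature maps of the last conv layer
    # fc_weights: [K, num_classes] weights of the final FC layer
    # M_c(x, y) = sum_k w_k^c * f_k(x, y)
    cam = np.tensordot(fc_weights[:, class_idx], feats, axes=([0], [0]))  # [H, W]
    cam = np.maximum(cam, 0)             # keep positive evidence only
    return cam / (cam.max() + 1e-8)      # scale to [0, 1] for display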

Grad-CAM

If CAM already reveals the model's regions of interest, why develop Grad-CAM? Because CAM requires the architecture to contain a GAP layer; if there is none, one must be added. That is inconvenient for models that are already trained and limits where CAM can be applied. Grad-CAM was designed precisely to overcome this limitation.

The CAM formula:

$$M_c(x, y) = \sum_k w_k^c \, f_k(x, y)$$

The Grad-CAM formula:

$$L^c_{\text{Grad-CAM}} = \operatorname{ReLU}\left( \sum_k \alpha_k^c A^k \right)$$

The ReLU keeps, in the final weighted sum over channel activation maps, only the locations with positive values, suppressing activations irrelevant to the target class (we only care about features that positively influence the predicted class).

The $A^k$ in the Grad-CAM formula and the $f_k(x, y)$ in the CAM formula both denote the feature maps extracted by the last convolutional layer. The only remaining difference between the two formulas, and the most important one, is how the feature maps are weighted. CAM weights each channel's feature map with $w_k^c$, the weight in the final fully connected layer (after GAP) that activates target class $c$ from channel $k$; to implement it, these weights are simply pulled out of the fully connected layer. Grad-CAM instead weights the feature maps with $\alpha_k^c$.

$\alpha_k^c$ is obtained by applying GAP to the gradients at the last convolutional layer:

$$\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k}$$

On the right-hand side, the $\frac{1}{Z}\sum_i\sum_j$ factor is the GAP operation ($Z$ is the number of spatial positions), and $\frac{\partial y^c}{\partial A_{ij}^k}$ is the gradient of the target class score $y^c$ with respect to the feature maps of the last convolutional layer, obtained by backpropagation through the model's computation graph.
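Putting the two formulas together, the following is a simplified sketch of Grad-CAM in Paddle, independent of the gradcam helper module used below (the hook-based feature capture and the manual_grad_cam name are our own illustrative choices, not the helper's API):

def manual_grad_cam(model, target_layer, img, class_idx):
    # img: [1, 3, 224, 224] preprocessed tensor; target_layer: last conv block
    cache = {}
    def hook(layer, inputs, outputs):
        cache['A'] = outputs                       # cache A^k on the forward pass
    handle = target_layer.register_forward_post_hook(hook)
    scores = model(img)                            # [1, num_classes]
    handle.remove()
    y_c = scores[0, class_idx]                     # target-class score y^c
    # dy^c / dA^k, obtained by backprop through the computation graph
    grads = paddle.grad(y_c, cache['A'])[0]        # [1, K, H, W]
    alpha = grads.mean(axis=[2, 3], keepdim=True)  # GAP over i, j -> alpha_k^c
    cam = F.relu((alpha * cache['A']).sum(axis=1)) # ReLU(sum_k alpha_k^c A^k)
    return cam[0].numpy()                          # [H, W] heatmap

Upsampled to the input resolution and overlaid on the original image, this is essentially the heatmap that the gradcam helper saves below.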

In [ ]

# GradCAM
from gradcam import GradCAM

model = ResNet(50, 4)
model.load_dict(paddle.load('COVID.pdparams'))
# Select the target convolutional layer
layer = 'block.15'
gradcam = GradCAM(model, layer)
# List the network layers
GradCAM.show_network(model)

In [ ]

# Generate Grad-CAM heatmaps in batch
def grad_cam(img_dir):
    img_list = os.listdir(img_dir)
    img_list = filter(lambda x: '.png' in x, img_list)
    for img_file in img_list:
        img_path = os.path.join(img_dir, img_file)
        img = cv2.imread(img_path)
        save_dir = os.path.split(img_dir)[-1]
        save_path = os.path.join("/home/aistudio/work/gradcam", f'{save_dir}')
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        save_path = os.path.join(save_path, f'{img_file}')
        gradcam.save(img, file=save_path)

grad_cam("/home/aistudio/work/dataset/COVID")

In [13]

# Display original images alongside their Grad-CAM heatmaps
def show_cam(img_dir, cam_dir):
    img_list = os.listdir(img_dir)
    img_list = filter(lambda x: '.png' in x, img_list)
    img_list = [os.path.join(img_dir, img_file) for img_file in img_list]
    img_list.sort(key=lambda x : x[-8:])
    cam_list = os.listdir(cam_dir)
    cam_list = filter(lambda x: '.png' in x, cam_list)
    cam_list = [os.path.join(cam_dir, cam_file) for cam_file in cam_list]
    cam_list.sort(key=lambda x : x[-8:])
    show_list = img_list[:8] + cam_list[:8]

    for i, path in enumerate(show_list):
        img = cv2.imread(path)
        img = img[:, :, ::-1] # BGR -> RGB
        plt.subplot(4, 4, i + 1)
        plt.imshow(img)

    plt.show()

show_cam("/home/aistudio/work/dataset/COVID", "/home/aistudio/work/gradcam/COVID")

<Figure size 640x480 with 16 Axes>

Summary

This project implements the full pipeline for COVID-19 CT image recognition with ResNet. Starting from the dataset: organizing its structure and renaming files (not covered in this report), splitting it into train/val/test, computing the mean and standard deviation, defining a custom dataset (paddle.io.Dataset), and preprocessing. Then building the ResNet architecture: defining BasicBlock, BottleneckBlock, and ResNet itself, and inspecting the network. Finally, training and evaluating the model and running predictions on images.

The project replaces plain SGD with the combination of the Momentum optimizer and the CosineWarmup schedule, which improves training to some extent; final test-set accuracy exceeds 95%. The project is primarily educational, and the implementation is spelled out step by step. To push the results further, consider: switching models, using a more sophisticated dataset split, training for more epochs and adjusting the learning rate, fine-tuning other hyperparameters, or adding data augmentation.


Reprinted from blog.csdn.net/m0_68036862/article/details/131348726