MS-Model【3】:Medical Transformer


前言

本文是 2021 年发表在 MICCAI 上的一篇文章,在当年的多项医学图像分割任务挑战中都获得了不错的成绩。本文主要介绍了这篇论文提出的 Medical Transformer 的结构及相关内容。

原论文链接:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation


1. Abstract & Introduction

1.1. Abstract

卷积架构存在着固有的归纳偏差(归纳偏差指的是神经网络模型会产生具有偏好的预测结果,也就是说归纳偏差会使得学习算法优先考虑具有某些特定属性的解),它们缺乏对图像中长程依赖性的理解。

文章提出了用 transformer 来做医学图像分割。要解决的问题是,transformer 在图像任务上相比卷积神经网络需要更大的数据集来训练,而医学图像处理的一个难题就是数据不足,数据集不够大

本文主要的贡献:

  • 提出了一种适用于较小数据集的门控位置敏感轴向注意机制
  • 引入了有效的局部-全局(LOGO)训练方法

1.2. Introduction

在 ConvNets 中,每个卷积核只关注整个图像中像素的局部子集,并迫使网络关注局部模式而不是全局上下文。虽然后续提出了一些弥补的 trick 如图像金字塔、Atrus 卷积和注意机制等,仍然无法完全解决这个问题。

由于图像的背景是分散的,学习对应于背景的像素之间的 long-range dependencies 可以帮助网络防止将一个像素错误地归类为掩码,从而减少假阳性(将 0 视为背景,1 视为分割掩码)。同样,当分割遮罩很大时,学习遮罩对应的像素之间的长距离依赖关系也有助于进行有效预测。

数字图像处理中,分割掩码主要用于:

  • 提取感兴趣区,用预先制作的感兴趣区掩模与待处理图像相乘,得到感兴趣区图像,感兴趣区内图像值保持不变,而区外图像值都为0
  • 屏蔽作用,用掩模对图像上某些区域作屏蔽,使其不参加处理或不参加处理参数的计算,或仅对屏蔽区作处理或统计
  • 结构特征提取,用相似性变量或图像匹配方法检测和提取图像中与掩模相似的结构特征

本文动机:

  • 传统的CNN的卷积层缺乏对图像中远程依赖关系的建模能力(即使不断地使用池化层能够提高感受野,但是会引起大量的结构损失)。而 Transformer 在捕获长程依赖关系方面具有良好的性能。
  • 由于带标注的医学数据稀缺是一个瓶颈问题,而 Transformer 结构往往有需要大量的数据才能够取得较好的性能,所以本文提出了 Gated Axial Attention 结构来考虑解决这个问题。(主要通过在自注意模块中引入额外的控制机制来扩展现有架构)
  • 另外为了提高 Transform 的性能,文章提出了局部-全局的训练策略(具体来说,我们对整个图像和各个 patches 进行操作以分别学习全局和局部特征)

2. Medical Transformer (MedT)

2.1. Model structure

MedT 有两个分支结构,一个全局分支结构和一个局部分支结构,这两个分支的输入是从初始 conv 块提取的特征图。该块有 3 个 conv 层,每个 conv 层后面都有 batch normalizationReLU 激活函数。

网络整体结构如下图所示,为两分支的 U-shape 结构,结构中的 EncoderDecoder

  • 在两个分支的 Encoder 中,使用 Transformer
    • 即本文只用了 Transformer 机制在 U-Net 结构的 Encoder 部分的 self-attention 机制上,并且不像其它 Transformer 用于 cv 的方法一样依赖于大数据集预训练的权重,本方法不需要预训练
    • Encoder 部分如下图 (b) 所示,包括 1 × 1 1 \times 1 1×1 卷积层(后面接着一个batch normalization)和两层 multi-head attention block,其中一层沿高度轴操作,另一层沿宽轴操作,每个 multi-head attention block 由提出的门控轴向注意层组成。
      • 每个 multi-head attention block 具有 8 8 8 个门控的轴向 multi-head
      • multi-head attention block 的输出通过另一个 1 × 1 1 \times 1 1×1 卷积层被添加到残差输入图中以产生输出注意图
  • 在两个分支的 Decoder 中,使用 conv
    • 每个 Decoder 块中,有一个卷积层,其后是一个上采样层和 ReLU 激活函数
  • 在两个分支中的每个 EncoderDecoder 的块之间有 skip connections

在这里插入图片描述

2.2. Attention

2.2.1. Self-Attention Overview

具有高度 H H H、权重 W W W 和通道 $C_{in} $的输入特征映射 x ∈ R C i n × H × W x \in R^{C_{in} \times H \times W} xRCin×H×W借助投影输入,使用以下公式计算自注意力层的输出 y ∈ R C o u t × H × W y \in R^{C_{out} \times H \times W} yRCout×H×W

在这里插入图片描述

参数含义:

  • 输入 x x x 计算映射得到 queries q = W Q x q = W_Q x q=WQx, keys k = W K x k = W_K x k=WKx,values v = W V x v = W_V x v=WVx
  • q i j , k i j , v i j q_{ij}, k_{ij}, v_{ij} qij,kij,vij 表示 query,key 和 value 在任意位置 i ∈ { 1 , … , H } i \in \{ 1, \dots, H \} i{ 1,,H} j ∈ { 1 , … , W } j \in \{ 1, \dots, W \} j{ 1,,W} 的值
  • 投影矩阵 W Q , W K , W V ∈ R C i n × C o u t W_Q, W_K, W_V \in R^{C_{in} \times C_{out}} WQ,WK,WVRCin×Cout 是可学习的

自注意力机制的局限:

  • 与卷积不同,自注意力机制能够从整个特征图中捕获非局部信息,但是这种对于相似度的计算的计算量非常之大
    • ViT 提出的时候,Transformer 的每个token 会对其他所有的每个 token 都计算注意力,所以是 ( h w ) 2 (hw)^2 (hw)2 次计算,这是非常庞大的计算量。有关 ViT 的其他内容可以参考我的另一篇 blog:CV-Model【6】:Vision Transformer
  • 因为没有引入位置信息,Transformer 其实是不具有位置表达能力的。

2.2.2. Axial-Attention

在这里插入图片描述

  • 为了克服计算复杂度较高的情况,将传统的自注意力模块分为宽度上以及高度上的两个注意力模块,称为 axial attention,大大减小了计算复杂度
    • Axial attention 的感受野是目标像素的同一行(或者同一列)的 W W W (或 H H H)个像素
    • 在高度轴和宽度轴上施加的轴向注意力有效地模拟了原始的自注意机制,并具有更好的计算效率
    • 一个轴向注意力层沿着一个特定的轴传播信息,为了捕获全局信息,我们分别为高度轴和宽度轴连续使用两个轴向注意力层,且两个轴向注意力层都采用了多头注意力机制
  • 为了添加位置表达能力,需要加一个 position embedding,就是用一个 onehot 的位置向量,经过一个全连接的 embedding,产生位置编码,这个全连接是可训练的
    • 这个位置编码原先只加在 Q , K , V Q, K, V Q,K,V Q Q Q 上,现在把它同时添加到 Q , K , V Q, K, V Q,K,V 三个上面

加上轴向注意力和多个位置编码的 trick 后,注意力机制如下所示(文章中给出的是宽度方向 w w w 上的注意里,高度方向 h h h 上的注意力类似):

在这里插入图片描述

参数含义:

  • w w w 表示对应的哪一行(width)
  • y i j y_{ij} yij 表示在具体某个位置的输出
  • r q , r k , r v ∈ R W × W r^q, r^k, r^v \in R^{W \times W} rq,rk,rvRW×W 表示 width-wise axial 注意力模型中的位置矩阵

高度方向同理

2.2.3. Gated Axial-Attention

然而上述的 trick 需要大量数据集进行训练,小量的数据不足以训练 QKV 的三个 position embedding,而医学数据集多数情况下就是少量的

在这种情况下,不准确的 position embedding 会给网络准确率带来负面影响,为此文章提出了个方法用来控制这个影响的程度,修改上述公式如下:

在这里插入图片描述

这里三个 G Q , G K , G V G_Q, G_K, G_V GQ,GK,GV 都是可学习的参数,当数据集不足以使得网络预测准确的 position embedding 时,网络的 G 会小一点,反之会大一点,因此起到一个所谓的 Gated 的作用。position embedding 只要相对位置一样,对不同的样本应该是一样的,因为 position embedding 只是位置信息,没有包含语义信息

通常,如果一个相对位置编码被准确学习,相对于那些没有被准确学习的编码,门控机制会赋予它较高的权重。

2.3. Local-Global Training

Transformer 做图像分割可以用 patch-wise 的方式去做,也就是说把一张完整图片切割成多个 patch,每个 patch 和这个 patch 对应的 mask 作为一个样本,用来训练 transformer,这样十分快

然而问题在于,一张图片的一个病灶可能比一个 patch 大,这样的话这个 patch 看起来就会很奇怪,因为被病灶充满了。这限制了网络学习 patch 间像素的任何信息或依赖性

Local-Global 这个部分的思路有点像多尺度的一个思考,他将网络分成了两个分支(branch):

  • 第一个分支为 Global branch,其工作在原分辨率的图像,不做特殊处理,采取的策略是经过较少的 block(进行两次 transformer block 后送进 Decoder),得到较大的距离依赖
    • 减少了 gated axial transformer layers 的数量,是因为发现所提出的 transformer 模型的前几个块足以模拟长距离依赖关系
  • 第一个分支为 Local branch,切分成 4 × 4 4 \times 4 4×4patch,每个 patch 单独送 transformer block 前向传播,patchpatch 之间没有任何联系,最后再把这 4 × 4 4 \times 4 4×4patchfeature map 通过 concat 操作拼接在一起
    • 每个 patch 通过网络进行前向反馈,并根据其位置对输出特征图进行重新采样,以获得输出特征图

将两个分支的输出特征图相加并通过 1 × 1 1 \times 1 1×1 卷积层以产生输出分割掩码。这种在图像的全局上下文上操作较浅的模型和在 patch 上操作较深的模型的策略提高了性能,因为全局分支关注的是高层信息,而局部分支关注的是更精细的细节

2.4. Loss function

MedT 使用预测和 ground truth 之间的二进制交叉熵 (CE) 损失来训练网络

在这里插入图片描述

参数含义:

  • w , h w, h w,h 是图像的尺寸
  • p ( x , y ) p(x, y) p(x,y) 对应于图像中的像素
  • p ^ ( x , y ) \hat{p}(x, y) p^(x,y) 表示在特定位置 ( x , y ) (x, y) (x,y) 的输出预测

总结

作者探索使用基于 transformer 的编码器架构来分割医学图像,而不需要任何预训练:

  1. 提出了一种适用于较小数据集的门控位置敏感轴向注意机制,提出了一种门控的轴向注意力层,作为网络编码器多头注意力模型的构建块。
  2. 引入了有效的局部-全局 (LOGO) 训练方法,在这种策略中,作者使用相同的网络架构训练整张图像和图像patch。全局分支通过对长期依赖关系进行建模来帮助网络中学习高级功能,而局部分支则通过对patch进行操作来关注更精细的功能。
  3. 提出了以轴向关注度为主要构建块的 MedT (Medical Transformer) 作为编码器的主要构建块,并采用 LOGO 策略对图像进行训练。
  4. 在三个不同的医学图像分割数据集进行了大量的实验,提高了卷积网络和其他基于 Transformer 架构的性能。

猜你喜欢

转载自blog.csdn.net/HoraceYan/article/details/128745781