Vision Transformer (ViT) and Its Variants

Table of Contents

0. Introduction to the Vision Transformer

1. ViT Model Architecture

1.1 Linear Projection of Flattened Patches

1.2 Transformer Encoder

1.3 MLP Head

1.4 ViT Architecture Diagram

1.5 Model Scaling

2. Hybrid ViT

3. Other Vision Transformer Variants

4. ViT Code

5. References


0. Introduction to the Vision Transformer

In 2017, Vaswani et al. proposed the Transformer in "Attention Is All You Need", the first model to rely entirely on self-attention to compute representations of its input and output; it has since been enormously successful in natural language processing.

In 2021, Dosovitskiy et al. applied the attention mechanism to computer vision and proposed the Vision Transformer (ViT). Given large-scale training data, ViT can reach accuracy comparable to CNN models; the figure below compares different ViT versions against ResNet and EfficientNet on several datasets.

Paper: "AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE"

Link: https://arxiv.org/abs/2010.11929

1. ViT Model Architecture

The authors give the following architecture diagram for the ViT model, which consists of three main parts:

1) Linear Projection of Flattened Patches (the Embedding layer, which maps image patches to vectors);

2) Transformer Encoder (the encoder, which learns from the embedded input);

3) MLP Head (the layer structure used for classification).

1.1 Linear Projection of Flattened Patches

In a standard Transformer module, the input is a sequence of tokens (vectors), i.e. a 2-D matrix [num_token, token_dim]. Image data, however, is a 3-D array of shape [H, W, C], so it must first be transformed by the Embedding layer into a format the Transformer can consume.

Taking ViT-B/16 as an example, the input image (224x224) is split into 16x16 patches, which yields (224 / 16) x (224 / 16) = 196 patches. Each patch is then mapped to a one-dimensional vector by a linear projection (Linear Projection).

Linear Projection: the linear mapping is implemented as a convolution with kernel size 16x16, stride 16, and 768 output channels, which changes the shape from [224, 224, 3] to [14, 14, 768]. Flattening the H and W dimensions (Flattened Patches) then gives [14, 14, 768] -> [196, 768], exactly the 2-D matrix the Transformer expects. Here 196 is the number of patches, and each patch of shape [16, 16, 3] is mapped by the convolution to a vector of length 768 (simply called a token from here on).
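
A minimal PyTorch sketch of this shape transformation (the tensor and variable names here are illustrative, not taken from the original code):

import torch
import torch.nn as nn

# ViT-B/16: a 16x16 convolution with stride 16 and 768 output channels acts as
# the linear projection of flattened patches.
proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

img = torch.randn(1, 3, 224, 224)         # [B, C, H, W]
feat = proj(img)                          # [1, 768, 14, 14]
tokens = feat.flatten(2).transpose(1, 2)  # [1, 196, 768] -> 196 patch tokens of length 768
print(tokens.shape)                       # torch.Size([1, 196, 768])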

Note that before the tokens are fed into the Transformer Encoder, a [class] token and Position Embeddings must be added:

1) [class] token: following BERT, the authors insert a dedicated classification token into the tokens just produced. This [class] token is a trainable parameter with the same format as the other tokens, i.e. a vector; for ViT-B/16 it has length 768. It is concatenated with the tokens generated from the image: Cat([1, 768], [196, 768]) -> [197, 768].

2) Position Embedding: a trainable one-dimensional position encoding (1D Pos. Emb.) that is added directly onto the tokens (element-wise), so its shape must match. For ViT-B/16, the shape after concatenating the [class] token is [197, 768], so the Position Embedding also has shape [197, 768] (see the shape sketch at the end of this subsection).

Self-attention lets every element interact with every other element, so by itself it carries no notion of order; but an image is a whole whose patches have a definite order and are spatially correlated. The positional embedding is therefore added to the patch embeddings so that the model can learn the spatial relationships between patches on its own.

The authors also ran a series of ablation experiments on the Position Embedding: there is a large gap between models with and without position embeddings, but almost no difference between the various ways of encoding position. Because the Transformer encoder operates on patch-level rather than pixel-level inputs, how spatial information is encoded matters less. The results are shown below:
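
A small shape sketch of the [class] token and position embedding for ViT-B/16 (in the real model both are learned nn.Parameter tensors; here they are zero tensors purely for illustration):

import torch

B, N, D = 1, 196, 768                      # batch, patch tokens, embedding dim (ViT-B/16)
patch_tokens = torch.randn(B, N, D)

cls_token = torch.zeros(1, 1, D)           # learnable [class] token in the real model
pos_embed = torch.zeros(1, N + 1, D)       # learnable 1D position embedding

x = torch.cat([cls_token.expand(B, -1, -1), patch_tokens], dim=1)  # [1, 197, 768]
x = x + pos_embed                          # added element-wise, so the shapes must match
print(x.shape)                             # torch.Size([1, 197, 768])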

1.2 Transformer Encoder

The Transformer Encoder consists mainly of the following components:

1) Layer Norm: the Transformer uses Layer Normalization, which speeds up training and makes it more stable;

2) Multi-Head Attention: the same as in the original Transformer (see the author's earlier post on "Attention Is All You Need");

3) Dropout/DropPath: the original paper's code uses plain Dropout, whereas the implementation reproduced here uses DropPath;

4) MLP Block: composed of Linear + GELU activation + Dropout. The first fully connected layer expands the number of nodes by a factor of 4 ([197, 768] -> [197, 3072]) and the second restores the original size ([197, 3072] -> [197, 768]). The MLP Block structure is shown in the figure below.

 

The Encoder structure is shown below: the left side is the actual structure and the right side is the structure from the paper, which omits the Dropout/DropPath layers. The Transformer Encoder is simply this Encoder Block stacked L times, with the MLP Block structure as shown above.
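
A simplified pre-norm Encoder Block sketch, written with PyTorch's built-in nn.MultiheadAttention instead of the custom Attention class defined in the code further below; Dropout/DropPath on the residual branches is omitted for brevity:

import torch
import torch.nn as nn

class TinyEncoderBlock(nn.Module):
    """Pre-norm block: x + MHA(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, dim=768, num_heads=12, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),   # 768 -> 3072
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),   # 3072 -> 768
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

block = TinyEncoderBlock()
print(block(torch.randn(1, 197, 768)).shape)   # torch.Size([1, 197, 768])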

   

1.3 MLP Head

The input and output shapes of the Transformer Encoder are identical: for ViT-B/16 the input is [197, 768] and the output is again [197, 768]. There is also a Layer Norm after the Transformer Encoder that is not drawn in the architecture diagram, as shown below:

Here we only need the classification information from the Transformer Encoder, so we extract just the output corresponding to the [class] token, i.e. the [1, 768] slice of the [197, 768] output. Because self-attention aggregates global information, this [class] token has already fused information from the other tokens. The final classification result is then obtained through the MLP Head. According to the paper, when training on ImageNet-21k the MLP Head consists of Linear + tanh activation + Linear; when transferring to ImageNet-1k or your own data, a single Linear layer is enough.
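
A sketch of this step for ViT-B/16, following the description above (the layer and variable names are illustrative):

import torch
import torch.nn as nn

encoder_out = torch.randn(1, 197, 768)   # Transformer Encoder output (after the final Layer Norm)
cls_out = encoder_out[:, 0]              # keep only the [class] token -> [1, 768]

# Fine-tuning on ImageNet-1k or your own data: a single Linear layer.
head = nn.Linear(768, 1000)
logits = head(cls_out)                   # [1, 1000]

# Pre-training on ImageNet-21k: Linear + tanh ("pre-logits") before the classifier.
pre_logits = nn.Sequential(nn.Linear(768, 768), nn.Tanh())
logits_21k = nn.Linear(768, 21843)(pre_logits(cls_out))   # [1, 21843]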

1.4 ViT架构图

1.5 Model Scaling

In the paper, the authors design "Base" and "Large" models following BERT and add a "Huge" model. The naming is also explained: ViT-L/16, for instance, denotes the "Large" variant with a 16x16 input patch size. Note that the Transformer's sequence length is inversely proportional to the square of the patch size, so models with smaller patches are computationally more expensive: at a 224x224 input, a 16x16 patch gives 196 tokens while a 32x32 patch gives only 49. Accordingly, the ViT source code provides 32x32-patch variants in addition to the 16x16 ones.

In the table below, Layers is the number of times the Encoder Block is stacked inside the Transformer Encoder, Hidden Size is the dimension of each token after the Embedding layer (the vector length), MLP size is the number of nodes in the first fully connected layer of the MLP Block (four times the Hidden Size), and Heads is the number of heads in Multi-Head Attention.
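
For reference, the three variants from Table 1 of the ViT paper can be summarized as a plain Python dict (the dict itself is only illustrative):

vit_configs = {
    #             Layers        Hidden Size      MLP size       Heads
    "ViT-Base":  dict(depth=12, embed_dim=768,  mlp_size=3072, num_heads=12),
    "ViT-Large": dict(depth=24, embed_dim=1024, mlp_size=4096, num_heads=16),
    "ViT-Huge":  dict(depth=32, embed_dim=1280, mlp_size=5120, num_heads=16),
}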

2. Hybrid ViT

After introducing model scaling in Section 4.1, the paper describes a hybrid model that combines conventional CNN feature extraction with the Transformer. The hybrid model uses ResNet-50 as the feature extractor, but this R50 uses StdConv2d rather than the ordinary Conv2d for its convolutions, and all BatchNorm layers are replaced with GroupNorm. In the original ResNet-50, stage 1 stacks 3 blocks, stage 2 stacks 4, stage 3 stacks 6, and stage 4 stacks 3; in this R50, the 3 blocks of stage 4 are moved into stage 3, so stage 3 stacks 9 blocks in total.

After feature extraction by the R50 backbone, the feature map has shape [14, 14, 1024] and is then fed into the Patch Embedding layer. Note that the kernel_size and stride of the Conv2d in this Patch Embedding are both 1, so it only adjusts the number of channels. Everything after that is identical to the ViT structure described above.
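
A shape sketch of this hybrid patch embedding (the R50 features are faked with a random tensor; names are illustrative):

import torch
import torch.nn as nn

r50_features = torch.randn(1, 1024, 14, 14)   # output of the R50 backbone, [B, C, H, W]

# kernel_size = stride = 1: the 14x14 spatial grid is preserved and only the
# channel count is mapped from 1024 to 768.
proj = nn.Conv2d(1024, 768, kernel_size=1, stride=1)
tokens = proj(r50_features).flatten(2).transpose(1, 2)   # [1, 196, 768]
print(tokens.shape)                                      # torch.Size([1, 196, 768])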

The table below, taken from the paper, compares ViT, ResNet, and the hybrid (R50+ViT) models. When the number of training epochs is small the hybrid model outperforms ViT, but as the epochs increase ViT becomes the better model.

3. Other Vision Transformer Variants

In the paper "A review of convolutional neural network architectures and their optimizations", the authors note that some studies have found ViT models harder to optimize than CNNs because ViT lacks spatial inductive bias. Using convolutional strategies inside ViT models to compensate for this missing bias can therefore improve their stability and performance. The review lists the following ViT variants:

1) LeViT (2021): introduces an attention-bias mechanism to integrate positional information.

Paper: "LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference"

Link: https://arxiv.org/abs/2104.01136

2) PVT (2021): the Pyramid Vision Transformer; introduces the pyramid structure of CNNs into the Transformer and is trained on dense partitions of the image to produce high-resolution outputs, overcoming the Transformer's shortcomings on dense prediction tasks.

Paper: "Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions"

Link: https://arxiv.org/abs/2102.12122

3) T2T-ViT (2021): recursively aggregates neighboring tokens into a single token, so the image is progressively structured into tokens; this yields an efficient backbone that is deeper and narrower.

Paper: "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"

Link: https://arxiv.org/abs/2101.11986

4) MobileViT (2021): combines MobileNetV2-style inverted residual blocks with ViT and clearly outperforms other lightweight networks.

Paper: "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer"

Link: https://arxiv.org/abs/2110.02178

5) VTs (2021): Visual Transformers; a tokenizer converts the feature map into a series of visual tokens, and a projector (Wu et al. 2020b) projects the processed visual tokens back onto the original feature map. Experiments show that VTs raise ResNet's ImageNet top-1 accuracy by 4.6 to 7 points while using fewer FLOPs and parameters.

Paper: "Visual Transformers: Token-based Image Representation and Processing for Computer Vision"

Link: https://arxiv.org/abs/2006.03677

6) Conformer (2021): combines the strengths of CNNs and Transformers by running them in parallel; a Feature Coupling Unit (FCU) exchanges local and global features at every stage.

Paper: "Conformer: Local Features Coupling Global Representations for Visual Recognition"

Link: https://arxiv.org/abs/2105.03889

7) BoTNet (2021): replaces the spatial convolutions in the last three bottleneck blocks of ResNet with global self-attention, significantly improving over the baseline.

Paper: "Bottleneck Transformers for Visual Recognition"

Link: CVPR 2021 Open Access Repository

8) CoAtNet (2021): argues that convolution and self-attention can be unified through simple relative attention, and that stacking convolutional layers and Transformer encoders yields strong results.

Paper: "CoAtNet: Marrying Convolution and Attention for All Data Sizes"

Link: https://arxiv.org/abs/2106.04803

9) Swin Transformer (2021): uses shifted windows to restrict self-attention computation to non-overlapping local windows while still allowing cross-window connections, reaching 87.3% top-1 accuracy on ImageNet-1K.

Paper: "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"

Link: https://arxiv.org/abs/2103.14030

4. ViT Code

The original ViT code:

https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

model.py:

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # work with diff dim tensors, not just 2D ConvNets
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + \
        torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B,196,768]
        x = self.norm(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim,   # dimension of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # scaling factor 1 / sqrt(head_dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
                                  self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # q @ k^T, scaled by 1 / sqrt(d_k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

# encoder block


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(
            1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(
            1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[
                      i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(
            self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(
                self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(
                x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        # plain image classification (no distillation token) skips this branch
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def vit_base_patch16_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model
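
A minimal usage example, appended for convenience (weights are randomly initialized here; no pretrained checkpoint is loaded):

if __name__ == "__main__":
    # quick sanity check of the input/output shapes
    model = vit_base_patch16_224(num_classes=5)
    dummy = torch.randn(1, 3, 224, 224)
    out = model(dummy)
    print(out.shape)   # torch.Size([1, 5])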

5. References

1. 深度学习之图像分类(十二): Vision Transformer – 魔法学院小学弟 (blog post)

2. Vision Transformer详解 – 太阳花的小绿豆 (CSDN blog)

3. "A review of convolutional neural network architectures and their optimizations", SpringerLink
