Dive into the code of the ViT model

The Vision Transformer (ViT) marks a significant milestone in the evolution of computer vision. ViT challenges the traditional view that images are best processed through convolutional layers, demonstrating that sequence-based attention mechanisms can effectively capture complex patterns, context, and semantics in images.

By decomposing images into manageable patches and leveraging self-attention, ViT captures local and global relationships, allowing it to excel in a variety of vision tasks, from image classification to object detection and more.

In this article, we’ll take a deep dive into the inner workings of ViT for image classification.

Introduction

The core idea of ViT is to treat the image as a sequence of fixed-size patches, which are flattened and converted into 1D vectors. These patches are then processed by a transformer encoder, enabling the model to capture global context and dependencies across the entire image. By dividing images into patches, ViT reduces the computational cost of processing large images while retaining the ability to model complex spatial interactions.

First, we import the ViT model for classification from the Hugging Face transformer library:

from transformers import ViTForImageClassification

import torch

import numpy as np

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

where patch16-224 means that the model accepts images of size 224x224, with each patch being 16 pixels wide and 16 pixels tall.
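The code in the rest of the article also assumes a preprocessed input tensor called image of shape (1, 3, 224, 224). As a minimal sketch (the file name is just a placeholder), it can be produced with the image processor that ships with the same checkpoint:

from PIL import Image
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# placeholder path: any RGB image will do
pil_image = Image.open("example.jpg").convert("RGB")

# resize to 224x224, normalize, and return a PyTorch tensor
image = processor(images=pil_image, return_tensors="pt")["pixel_values"]
print(image.shape)  # torch.Size([1, 3, 224, 224])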

The following is an example of a model architecture:

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): ViTOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )

        .......

        (11): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): ViTOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  )
  (classifier): Linear(in_features=768, out_features=1000, bias=True)
)

Embeddings

Patch embedding

Converting the image into patches is done with a Conv2D layer. A Conv2D layer applies a two-dimensional convolution to the input to learn features and patterns from the image.

In this case, however, the Conv2D layer is used to divide the image into an NxN grid of patches, controlled by the stride parameter. The stride determines the step size by which the filter slides over the input.

Since our image is 224x224 and each patch is 16x16, there are 224/16 = 14 patches along each dimension. By choosing stride=16, the filter never overlaps itself, so the image is effectively divided into 14x14 non-overlapping patches.
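We can confirm this directly from the projection layer (a quick check, using the image tensor defined above):

proj = model.vit.embeddings.patch_embeddings.projection
print(proj)               # Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
print(proj(image).shape)  # torch.Size([1, 768, 14, 14]) -> a 14 x 14 grid of patches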

To be more intuitive, assume that the shape of the image is 4x4 and the stride is 2:

[Animation: a convolution filter sliding over a 4x4 image with stride 2]
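To make this concrete, here is a toy example (not part of the model): a 4x4 single-channel image convolved with one 2x2 filter at stride 2, producing one value per non-overlapping patch:

import torch

toy_image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)  # one 4x4 image, 1 channel
toy_filter = torch.ones(1, 1, 2, 2)                                    # a single 2x2 filter

# stride 2 -> four non-overlapping 2x2 patches, one output value each
toy_out = torch.nn.functional.conv2d(toy_image, toy_filter, stride=2)
print(toy_out.shape)  # torch.Size([1, 1, 2, 2])
print(toy_out)        # each value is the sum over one 2x2 patch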

So, for example, the value of the first filter for the first patch, and for the patch directly below it, can be reproduced like this:

proj = model.vit.embeddings.patch_embeddings.projection
w = proj.weight  # shape (768, 3, 16, 16)
b = proj.bias    # shape (768,)

# first filter applied to the first (top-left) 16x16 patch
torch.allclose(torch.sum(image[0, :, 0:16, 0:16] * w[0]) + b[0],
               proj(image)[0][0][0, 0], atol=1e-6)
# True

# first filter applied to the patch one step down the height dimension
torch.allclose(torch.sum(image[0, :, 16:32, 0:16] * w[0]) + b[0],
               proj(image)[0][0][1, 0], atol=1e-6)
# True

The pattern is clear: to compute each patch we move 16 pixels at a time, so the patches never overlap. Doing this over the entire image produces a 1 x 14 x 14 tensor, where each patch is represented by a single number computed by the first Conv2D filter.

However, there are 768 filters, so we end up with a 768 x 14 x 14 tensor. Each patch therefore has a 768-dimensional representation, which is our patch embedding. We also flatten and transpose the tensor, so the embedding shape becomes [batch_size, 196, 768]: the 14 x 14 grid is flattened into a sequence of 196 patches, each with an embedding of size 768.

embeddings = model.vit.embeddings.patch_embeddings.projection(image)
# shape (batch_size, 768, 14, 14)

embeddings = embeddings.flatten(2).transpose(1, 2)
# shape (batch_size, 196, 768)

If we want to reproduce the entire layer from scratch, this is the code:

batch_size = 1
F = 768      # number of filters
H1 = 14      # output height - 224/16
W1 = 14      # output width - 224/16
stride = 16
HH = 16      # patch height
WW = 16      # patch width

w = model.vit.embeddings.patch_embeddings.projection.weight
b = model.vit.embeddings.patch_embeddings.projection.bias

out = np.zeros((batch_size, F, H1, W1))

for n in range(batch_size):
    for f in range(F):
        for i in range(H1):
            for j in range(W1):
                # convolution: element-wise product of the patch and the filter, summed, plus bias
                out[n, f, i, j] = torch.sum(image[n, :, i*stride:i*stride+HH, j*stride:j*stride+WW] * w[f]) + b[f]

# compare against the Conv2d projection output (before flattening)
proj_out = model.vit.embeddings.patch_embeddings.projection(image)
np.allclose(out[0], proj_out[0].detach().numpy(), atol=1e-5)
# True

Now, if you are familiar with language transformers (I cover them in my previous article on the BERT Transformer), you will recall that the [CLS] token serves as a concise, informative summary of the entire input, which the model uses to make its prediction. ViT also has a [CLS] token with exactly the same role, and it is prepended to the patch representations computed above.

The [CLS] token is a learnable parameter, trained with backpropagation along with the rest of the model:

import torch.nn as nn

# at initialization the [CLS] token is a random learnable parameter;
# the pretrained value lives in model.vit.embeddings.cls_token
cls_token = nn.Parameter(torch.randn(1, 1, 768))
cls_tokens = cls_token.expand(batch_size, -1, -1)

# prepend the [CLS] token to the patch embeddings
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
Position embedding

Just like in the language transformer, to preserve the positional information of patches, ViT includes positional embeddings. Positional embeddings help the model understand the spatial relationships between different patches, allowing it to capture the structure of the image.

The positional embeddings form a tensor with the same shape as the embeddings after prepending the [CLS] token, i.e. [batch_size, 197, 768].
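A quick shape check on the pretrained position embeddings (one row for the [CLS] token plus 196 patch positions):

print(model.vit.embeddings.position_embeddings.shape)
# torch.Size([1, 197, 768])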

embeddings = embeddings + model.vit.embeddings.position_embeddings
Dropout

After the embedding layer comes a Dropout layer. Dropout replaces some values with zeros, with a given dropout probability. It helps reduce overfitting: since we randomly mask the signals of certain neurons, the network has to find other paths to reduce the loss function and therefore learns to generalize better instead of relying on specific paths. We can also view Dropout as a model-ensemble technique, since during training we deactivate a random subset of neurons at each step and, at evaluation time, we effectively merge these "different" networks.
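As a small toy illustration (not part of the model), dropout only zeroes activations in training mode and is a no-op in evaluation mode:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

drop.train()
print(drop(x))  # roughly half the values are zeroed, the rest scaled by 1 / (1 - p) = 2

drop.eval()
print(drop(x))  # identity: all ones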

At the end of the embedding layer we have:

# compute the embedding
embeddings = model.vit.embeddings.patch_embeddings.projection(image)
embeddings = embeddings.flatten(2).transpose(1, 2)

# prepend the pretrained [CLS] token
cls_tokens = model.vit.embeddings.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)

# add the positional embeddings
embeddings = embeddings + model.vit.embeddings.position_embeddings

# dropout
embeddings = model.vit.embeddings.dropout(embeddings)

Encoder

ViT uses a stack of transformer encoder blocks, similar to the ones used in language models such as BERT. Each block consists of multi-head self-attention and a feed-forward neural network. The self-attention mechanism lets the model capture relationships between different patches, while the feed-forward network applies non-linear transformations.

Specifically, each layer consists of self-attention, intermediate and output modules.

(0): ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (output): ViTSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (intermediate): ViTIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
  )
  (output): ViTOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
Self-attention

Self-attention is a key mechanism within the Vision Transformer (ViT) model, which enables it to capture the relationships and dependencies between different patches in an image. It plays a crucial role in extracting contextual information and understanding long- and short-range interactions between patches.

Each patch is associated with three vectors: a query, a key and a value. They are obtained by linearly projecting the original patch embeddings. The query of a patch is compared against the keys of all patches to decide how relevant each of them is to it, while the values carry the content that is actually aggregated into the new representation.

Since we already computed the embeddings in the previous section, we compute the keys, queries, and values by projecting the embeddings using the key, query, and value matrices:

import math 

import torch.nn as nn

torch.manual_seed(0)

hidden_size = 768

num_attention_heads = 12

attention_head_size = hidden_size // num_attention_heads # 64

hidden_states = embeddings

# apply LayerNorm to the embeddings
hidden_states = model.vit.encoder.layer[0].layernorm_before(hidden_states)

# take first layer of the Transformer
layer_0 = model.vit.encoder.layer[0]

# shape (768, 64) 
key_matrix = layer_0.attention.attention.key.weight.T[:, :attention_head_size]

key_bias = layer_0.attention.attention.key.bias[:attention_head_size]

query_matrix = layer_0.attention.attention.query.weight.T[:, :attention_head_size] 

query_bias = layer_0.attention.attention.query.bias[:attention_head_size]

value_matrix = layer_0.attention.attention.value.weight.T[:, :attention_head_size]

value_bias = layer_0.attention.attention.value.bias[:attention_head_size]

# compute key, query and value for the first head attention
# all of shape (b_size, 197, 64)
key_1head = hidden_states @ key_matrix + key_bias

query_1head = hidden_states @ query_matrix + query_bias

value_1head = hidden_states @ value_matrix + value_bias

Note that we applied LayerNorm to the embeddings before the attention (layernorm_before); we will discuss what LayerNorm does later.

For each query vector, an attention score is calculated by measuring the compatibility, or similarity, between that query and the key vectors of all patches. This is done with a dot product followed by a Softmax, which produces normalized attention scores of shape [b_size, 197, 197]. The attention matrix is square because every patch attends to every other patch, which is why this is called self-attention. These scores indicate how much focus should be placed on each patch when processing the query patch. Since the new embedding of each patch in the next layer is a weighted combination of the values of all patches, weighted by the attention scores, we obtain a contextualized embedding for each patch: it is derived from all the other patches in the image.

To make this clearer, recall that at the beginning we split the image into patches with a Conv2D layer to obtain a 768-dimensional embedding for each patch. These embeddings are independent of each other, since the patches do not overlap and there is no interaction between them. In the transformer layers, however, the patch embeddings get mixed together and become functions of the other patch embeddings. For example, the contextualized embedding in the first layer is computed as follows:

# shape (b_size, 197, 197)
# compute the attention scores by dot product of query and key
attention_scores_1head = torch.matmul(query_1head, key_1head.transpose(-1, -2))

attention_scores_1head = attention_scores_1head / math.sqrt(attention_head_size)

attention_probs_1head = nn.functional.softmax(attention_scores_1head, dim=-1)

# contextualized embedding for this layer
context_layer_1head = torch.matmul(attention_probs_1head, value_1head)

If we zoom in and look at the first patch:

patch_n = 1
# shape (197,)
print(attention_probs_1head[0, patch_n])
# [2.4195e-01, 7.3293e-01, ..,
#  2.6689e-06, 4.6498e-05, 1.1380e-04, 5.1591e-06, 2.1265e-05]

This is the attention row used to build its new embedding (index 0 is the [CLS] token). The new embedding is a combination of the embeddings of all patches, attending mostly to the first patch itself (0.73) and to the [CLS] token (0.24), with the small remainder spread over all the other patches. This isn't always the case, though: in later layers the first patch may attend more to its neighbours than to itself and the [CLS] token, or even to distant patches, depending on what the model finds useful for the task at hand.
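If you want to inspect this across layers yourself, the Hugging Face model can return all attention maps directly (a quick sketch, reusing the image tensor from earlier):

with torch.no_grad():
    out = model(image, output_attentions=True)

# a tuple of 12 tensors, one per layer, each of shape (b_size, 12, 197, 197)
all_attentions = out.attentions

# attention distribution of patch 1 in the second layer, first head
print(all_attentions[1][0, 0, 1])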

Also, you may have noticed that I only took the first 64 columns of the query, key and value weight matrices. Those first 64 columns correspond to the first attention head; this model actually has 12 attention heads, and each of them builds a different representation of the patches. For instance, looking at the third attention head for the same first patch, we see that it attends more to the second patch (0.27) than to itself, unlike the first head, which focused mostly on the patch itself.

# shape (197,)
# [2.6356e-01, 1.2783e-03, 2.6888e-01, ... , 1.8458e-02]
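For reference, that third-head row can be reproduced by slicing the next block of 64 columns from the same weight matrices (a sketch under the same assumptions as the single-head code above):

# the third head corresponds to columns 128:192 of the projection matrices
head = 2
cols = slice(head * attention_head_size, (head + 1) * attention_head_size)

key_3head = hidden_states @ layer_0.attention.attention.key.weight.T[:, cols] + layer_0.attention.attention.key.bias[cols]
query_3head = hidden_states @ layer_0.attention.attention.query.weight.T[:, cols] + layer_0.attention.attention.query.bias[cols]

scores_3head = (query_3head @ key_3head.transpose(-1, -2)) / math.sqrt(attention_head_size)
probs_3head = nn.functional.softmax(scores_3head, dim=-1)
print(probs_3head[0, patch_n])  # attention of the first patch in the third head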

Therefore, different attention heads capture different kinds of relationships between patches, helping the model look at the image from several perspectives.

To compute all of these heads in parallel, we can do the following:

def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

mixed_query_layer = layer_0.attention.attention.query(hidden_states)

key_layer = transpose_for_scores(layer_0.attention.attention.key(hidden_states))

value_layer = transpose_for_scores(layer_0.attention.attention.value(hidden_states))

query_layer = transpose_for_scores(mixed_query_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

attention_scores = attention_scores / math.sqrt(attention_head_size)

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = layer_0.attention.attention.dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

new_context_layer_shape = context_layer.size()[:-2] + (hidden_size,)

context_layer = context_layer.view(new_context_layer_shape)

After the self-attention, we apply one more projection layer and dropout, and the attention block is complete!

output_weight = layer_0.attention.output.dense.weight

output_bias = layer_0.attention.output.dense.bias

attention_output = context_layer @ output_weight.T + output_bias

attention_output = layer_0.attention.output.dropout(attention_output)

Oh, wait a minute, I promised to explain LayerNorm operations.

Layer Normalization is a technique used to improve the training and performance of deep learning models. It addresses the problem of internal covariate shift: as the weights of the network change during training, the distribution of the inputs to each layer can shift significantly, making it harder for the model to converge. Layer normalization tackles this by ensuring that the inputs to each layer have a consistent mean and variance, which stabilizes learning. Concretely, each patch embedding is normalized using its own mean and standard deviation so that it has zero mean and unit variance; a trained weight and bias are then applied, so the model can still learn whatever mean and variance suit it during training. Because the statistics are computed for each example (and each token) independently, this differs from batch normalization, which normalizes across the batch dimension and therefore depends on the other examples in the batch.

Let's take the first patch embedding as an example:

first_patch_embed = embeddings[0][0]

# compute the mean of the first patch embedding
first_patch_mean = first_patch_embed.mean()

# compute its (biased) variance
first_patch_var = (first_patch_embed - first_patch_mean).pow(2).mean()

# standardize: zero mean, unit variance
first_patch_standardized = (first_patch_embed - first_patch_mean) / torch.sqrt(first_patch_var + 1e-12)

# apply the trained weight and bias vectors
first_patch_norm = layer_0.layernorm_before.weight * first_patch_standardized + layer_0.layernorm_before.bias


Intermediate

Before the Intermediate module, we apply another layer normalization and a residual connection. It should be clear by now why we apply another layer normalization: we need to normalize the contextualized embeddings coming out of self-attention to improve convergence. But what is that residual connection I mentioned?

Residual connections are a key component of deep neural networks and help alleviate the difficulties of training very deep architectures. As we stack more layers, we run into the vanishing/exploding gradient problem. With vanishing gradients, the model can no longer learn: the back-propagated gradients approach zero, so the early layers stop updating their weights and stop improving. With exploding gradients, the opposite happens: the weight updates become extreme and the gradients blow up, so training diverges. Proper weight initialization and normalization help, but it has been observed that even when a very deep network trains stably, its performance can still degrade simply because optimization becomes harder. Residual connections improve performance and make the network easier to optimize even as we keep increasing the depth.

How is it achieved? It's simple: we just add the original input to the output of some transformation of that input:

# illustrative only: a small block of transformations
transformations = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))

output = input + transformations(input)

Another key insight is that if the optimal mapping for a block is close to the identity, the transformation only has to learn a small residual (close to zero), because the input is added back unchanged. The network is then free to modify or refine the features only where that is actually needed.

In our case, the residual connection is the sum between the initial embedding and all transformed embeddings from the attention layer (attention_output).

# first residual connection - NOTE the hidden_states are the 
# `embeddings` here
hidden_states = attention_output + hidden_states

# in ViT, layernorm is also applied after self-attention
layer_output = layer_0.layernorm_after(hidden_states)

In the Intermediate class we perform linear projection and apply nonlinearity:

layer_output_intermediate = layer_0.intermediate.dense(layer_output)

layer_output_intermediate = layer_0.intermediate.intermediate_act_fn(layer_output_intermediate)

The nonlinearity used in ViT is the GeLU activation function. It is defined as the input multiplied by the cumulative distribution function Φ of the standard normal distribution:

GELU(x) = x · Φ(x) = 0.5 · x · (1 + erf(x / √2))

Usually, to speed up calculations, it is approximated by:

GELU(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))
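As a quick numerical check (a sketch using PyTorch's built-in GeLU; the approximate argument assumes a reasonably recent PyTorch), the exact and tanh-approximated versions are very close:

x = torch.linspace(-3, 3, 101)

gelu_exact = torch.nn.functional.gelu(x)                     # x * Phi(x)
gelu_tanh = torch.nn.functional.gelu(x, approximate="tanh")  # tanh approximation

print((gelu_exact - gelu_tanh).abs().max())  # very small difference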

From the chart below, we can see that ReLU, given by max(x, 0), is monotonic, convex and linear in the positive domain, whereas GeLU is non-monotonic, non-convex and non-linear, which lets it approximate complex functions more easily.

Additionally, the GeLU function is smooth: unlike ReLU, which has a sharp kink at zero, GeLU transitions smoothly across all values, which makes gradient-based optimization easier during training.

[Plot: ReLU vs. GeLU activation functions]
Output

The last part of the encoder block is the output module. To compute it, we already have all the pieces we need: a linear projection, dropout and a residual connection:

# linear projection
output_dense = layer_0.output.dense(layer_output_intermediate)

# dropout
output_drop = layer_0.output.dropout(output_dense)

# residual connection - NOTE these hidden_states are the output of the
# first residual connection above (the input to layernorm_after)
output_res = output_drop + hidden_states # shape (b_size, 197, 768)

OK, we have completed the first ViTLayer. Only 11 more layers to go, and that's the hard part...

Just kidding! We are actually done: all the other layers are exactly the same as the first one. The only difference is that each subsequent layer takes the previous layer's output_res as its input instead of the patch embeddings.

Therefore, the output after 12 layers of the encoder is:

torch.manual_seed(0)

# masking heads in a given layer
layer_head_mask = None

# output attention probabilities
output_attentions = False

embeddings = model.vit.embeddings(image)

hidden_states = embeddings

for l in range(12):
    hidden_states = model.vit.encoder.layer[l](hidden_states, layer_head_mask, output_attentions)[0]

output = model.vit.layernorm(hidden_states)
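As a sanity check (a sketch; it should print True if everything above matches the library's own forward pass), we can compare against the encoder output computed by the model itself:

with torch.no_grad():
    reference = model.vit(image).last_hidden_state

print(torch.allclose(output, reference, atol=1e-5))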

Pooler

Typically, in transformer models, the pooler is the component that aggregates information from the sequence of token embeddings produced by the encoder blocks. Its role is to generate a fixed-size representation that captures global context and summarizes the information extracted from the image patches. This compact, contextual representation of the image can then be used for downstream tasks such as image classification.

In this case, the pooling is very simple: we take the [CLS] token and use it as the compact, contextual representation of the image.

pooled_output = output[:, 0, :] # shape (b_size, 768)

Classifier

Finally, we are ready to use pooled_output to classify the images. The classifier is a simple linear layer whose output dimension is equal to the number of classes:

logits = model.classifier(pooled_output) # shape (b_size, num_classes)
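To turn these logits into a human-readable prediction, we can take the argmax and look it up in the checkpoint's label mapping (the ImageNet-1k classes shipped with the config):

predicted_class = logits.argmax(-1).item()
print(model.config.id2label[predicted_class])  # e.g. an ImageNet class name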

Conclusion

ViT has had a profound impact on computer vision, challenging and in many applications replacing convolutional neural networks, which is why understanding how it works is so important. And don't forget that ViT's main component, the Transformer architecture, has its origins in natural language processing (NLP), so you may also want to check out my previous article on the BERT Transformer.

References

[1] Hugging Face Transformers: https://github.com/huggingface/transformers

[2] A. Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", arXiv:2010.11929.
