Sine and cosine positional encoding: explanation and code implementation



1. Explanation of sine and cosine positional encoding

In the Transformer, positional encoding introduces position information into the model, and it usually takes the form of a combination of sine and cosine functions. The formula is as follows:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

Here, PE(pos, i) is the value at the pos-th position and the i-th dimension of the positional encoding matrix; d_model is the dimension of the model's embedding vector; and the i in the exponent indexes the sin/cos pairs, so even dimensions 2i take the sine and odd dimensions 2i+1 take the cosine. This encoding injects position information so that the Transformer model can process sequence data.
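As a quick numeric sanity check of the formula (a minimal sketch, not in the original post), here is PE at pos = 1 for the first sin/cos pair, with d_model = 6 as in the example below:

import math

d_model = 6
pos = 1
# first pair: 2i = 0, so the denominator 10000^(0/d_model) is 1
print(math.sin(pos / 10000 ** (0 / d_model)))  # sin(1) ≈ 0.8415, dimension 0
print(math.cos(pos / 10000 ** (0 / d_model)))  # cos(1) ≈ 0.5403, dimension 1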
Assuming a sequence length of 4 and a positional encoding dimension of 6, the positional encoding matrix has rows pos = 0, 1, 2, 3, and each row reads, across columns dim0 through dim5:

  sin(pos / 10000^(0/6))   cos(pos / 10000^(0/6))   sin(pos / 10000^(2/6))   cos(pos / 10000^(2/6))   sin(pos / 10000^(4/6))   cos(pos / 10000^(4/6))
The argument inside each trigonometric function can be split into two parts: the numerator pos plays the role of the variable x, and the denominator 10000^(2i/d_model) sets the period. Since sin(x) has period 2π, sin(x / 10000^(2i/d_model)) has period T = 2π · 10000^(2i/d_model).
Analysis by column: for example, column dim0 has 2i = 0, so its period is T = 2π · 10000^0 = 2π, a fixed value; as pos runs from 0 to 3 down the column, each column is a trigonometric function of pos with a fixed period.
Analysis by row:
For example, in row pos0 (and in every row), the period changes every two elements: each sin/cos pair shares one frequency, and the period grows as the dimension index increases. Reading across a row is therefore like sampling a trigonometric function with a variable period T; the sketch below makes the per-column periods concrete.
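A minimal sketch (not in the original post) that prints the angular frequency and period of each sin/cos column pair for the 6-dimensional example above:

import math

d_model = 6
for i in range(d_model // 2):
  omega = 1.0 / 10000 ** (2 * i / d_model)  # angular frequency of the pair (dim 2i, dim 2i+1)
  T = 2 * math.pi / omega                   # period; dim0/dim1 give T = 2*pi ≈ 6.28
  print(f"dim{2*i}/dim{2*i+1}: omega = {omega:.5f}, T = {T:.2f}")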

2. Code implementation

The code is as follows (example):
1. Implement the matrix in the above table:

import torch

def create_pe_absolute_sincos_embedding(n_pos_vec, dim):
  assert dim % 2 == 0, "dim must be even"
  position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)

  # omega_i = 1 / 10000^(2i / dim): one frequency per sin/cos pair
  omega = torch.arange(dim // 2, dtype=torch.float)
  omega /= dim / 2.
  omega = 1. / (10000 ** omega)

  # outer product: theta[pos, i] = pos * omega_i
  theta = n_pos_vec[:, None] @ omega[None, :]
  emb_sin = torch.sin(theta)
  emb_cos = torch.cos(theta)

  # even dimensions take the sine, odd dimensions the cosine
  position_embedding[:, 0::2] = emb_sin
  position_embedding[:, 1::2] = emb_cos

  return position_embedding
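
As a sanity check (a sketch added here, not part of the original post), the vectorized version can be compared against a direct, unvectorized transcription of the formula:

import math

def pe_reference(n_pos, dim):
  # naive double loop straight from the sin/cos formula
  pe = torch.zeros(n_pos, dim)
  for pos in range(n_pos):
    for i in range(dim // 2):
      angle = pos / 10000 ** (2 * i / dim)
      pe[pos, 2 * i] = math.sin(angle)
      pe[pos, 2 * i + 1] = math.cos(angle)
  return pe

n_pos_vec = torch.arange(8, dtype=torch.float)
assert torch.allclose(create_pe_absolute_sincos_embedding(n_pos_vec, 16), pe_reference(8, 16), atol=1e-6)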

2. Initialize the sequence length and position encoding dimensions, and calculate the position encoding matrix:

n_pos = 512
dim = 768
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
pe = create_pe_absolute_sincos_embedding(n_pos_vec, dim)
print(pe)
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  8.2843e-01,  ...,  1.0000e+00,
          1.0243e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.2799e-01,  ...,  1.0000e+00,
          2.0486e-04,  1.0000e+00],
        ...,
        [ 6.1950e-02,  9.9808e-01,  5.3552e-01,  ...,  9.9857e-01,
          5.2112e-02,  9.9864e-01],
        [ 8.7333e-01,  4.8714e-01,  9.9957e-01,  ...,  9.9857e-01,
          5.2214e-02,  9.9864e-01],
        [ 8.8177e-01, -4.7168e-01,  5.8417e-01,  ...,  9.9856e-01,
          5.2317e-02,  9.9863e-01]])
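
The printed values can be spot-checked by hand against the formula (a quick check, not in the original post): the first row is all (sin 0, cos 0) pairs, i.e. alternating 0 and 1, and the early entries of the next rows are plain sines and cosines of pos:

import math
print(math.sin(1.0))  # 0.8415, matches pe[1, 0] above
print(math.cos(2.0))  # -0.4161, matches pe[2, 1] above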

3. Visualize the position encoding matrix by row:

# different pos (rows)
import matplotlib.pyplot as plt

x = list(range(dim))
for index, item in enumerate(pe):
  if index % 50 != 1:
    continue
  y = item.tolist()
  plt.plot(x, y, label=f"pos {index}")
  plt.legend()
  plt.show()

Sampling one row every 50 positions gives 11 curves over a sequence length of 512. The figure below shows the position encoding curves at pos 0, pos 250, and pos 500:
[Figure: position encoding curves at pos 0, pos 250, and pos 500]

4. Visualize the position encoding matrix by column:

# different dim (columns)
x = list(range(n_pos))
for index, item in enumerate(pe.transpose(0, 1)):
  if index % 50 != 1:
    continue
  y = item.tolist()
  plt.plot(x, y, label=f"dim {index}")
  plt.legend()
  plt.show()

Sampling one column every 50 dimensions gives 16 curves, since the encoding dimension is 768. The figure below shows the position encoding curves at dim 0, dim 350, and dim 750:
[Figure: position encoding curves at dim 0, dim 350, and dim 750]
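
Beyond per-row and per-column curves, the whole matrix can also be rendered at once as a heatmap; this is an extra sketch not in the original post, reusing the pe tensor computed above:

# heatmap of the full (n_pos x dim) position encoding matrix
plt.figure(figsize=(10, 6))
plt.imshow(pe.numpy(), aspect="auto", cmap="viridis")
plt.xlabel("dim")
plt.ylabel("pos")
plt.colorbar()
plt.show()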

Source: blog.csdn.net/Brilliant_liu/article/details/135033645