Hands-on deep learning (50) - multi-head attention mechanism

1. Why use a multi-head attention mechanism

  • The so-called self-attention mechanism directly computes, through some operation, the attention weight of each position in the sentence during encoding, and then computes the hidden vector representation of the entire sentence as a weighted sum.

  • A drawback of the self-attention mechanism is that when the model encodes the information at the current position, it tends to focus excessively on that position itself, so the authors propose the multi-head attention mechanism to address this problem.

2. What is the multi-head attention mechanism

  In practice, given the same set of queries, keys, and values, we want the model to learn different behaviors based on the same attention mechanism and then combine those behaviors as knowledge, for example to capture dependencies of various ranges (such as short-range and long-range dependencies). It may therefore be beneficial to allow the attention mechanism to combine different representation subspaces.

  To this end, instead of performing a single attention pooling, we can independently learn $h$ sets of different linear projections to transform the queries, keys, and values. These $h$ sets of transformed queries, keys, and values are then fed into attention pooling in parallel. Finally, the outputs of the $h$ attention poolings are concatenated and transformed by another learnable linear projection to produce the final output. This design is called multi-head attention, and each of the $h$ attention pooling outputs is called a head (Vaswani et al., 2017). The figure below shows multi-head attention, where the learnable linear transformations are implemented with fully connected layers.

3. Multi-head attention mechanism model and theoretical calculation

  Before implementing multi-head attention, let us formalize the model mathematically. Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as

$$\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}) \in \mathbb{R}^{p_v},$$

where the learnable parameters include $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$, and $f$ is the attention pooling function, which can be additive attention or scaled dot-product attention. The output of multi-head attention then undergoes another linear transformation, applied to the concatenation of the $h$ heads, so its learnable parameter is $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$:

$$\mathbf{W}_o \begin{bmatrix}\mathbf{h}_1\\\vdots\\\mathbf{h}_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$

Based on this design, each head may attend to a different part of the input, so functions more complex than a simple weighted average can be represented.

Multi-head attention with mask:

  • When the decoder outputs an element of the sequence, it should not attend to the elements that come after it
  • This is realized with a mask: when computing the output for $x_i$, we pretend the current sequence has length $i$ (see the sketch after this list)
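
A minimal, standalone sketch of this masking idea (the sequence length and the random scores below are made up for illustration; the implementation later in this post instead achieves the same effect by passing valid_lens into the attention pooling):

import torch

# Causal mask: query position i may only attend to key positions 0..i,
# i.e. the sequence is treated as if it had length i + 1 when producing output i.
seq_len = 4
scores = torch.randn(seq_len, seq_len)                           # raw attention scores, (query pos, key pos)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~causal_mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)                   # row i is non-zero only for keys 0..i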

At a finer granularity, the computation of each head in multi-head attention can be written out explicitly, as in the sketch below.
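
Assuming scaled dot-product attention as the pooling function $f$ (the choice made in the implementation below, with $p_q = p_k$), and stacking $n$ queries, $m$ keys, and $m$ values row-wise into matrices $\mathbf Q \in \mathbb R^{n \times d_q}$, $\mathbf K \in \mathbb R^{m \times d_k}$, and $\mathbf V \in \mathbb R^{m \times d_v}$, one standard way to write a single head is

$$\mathbf H_i = \mathrm{softmax}\!\left(\frac{(\mathbf Q \mathbf W_i^{(q)\top})(\mathbf K \mathbf W_i^{(k)\top})^\top}{\sqrt{p_k}}\right)\mathbf V \mathbf W_i^{(v)\top} \in \mathbb R^{n \times p_v},$$

where the softmax is taken row-wise, and the $h$ matrices $\mathbf H_1, \ldots, \mathbf H_h$ are concatenated along the feature dimension before the final projection by $\mathbf W_o$.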

4. Hands-on implementation of the multi-head attention mechanism layer

  In our implementation, we choose scaled dot-product attention for each attention head. To avoid a significant growth in computational cost and in the number of parameters, we set $p_q = p_k = p_v = p_o / h$. It is worth noting that if we set the number of outputs of the linear transformations of the query, key, and value to $p_q h = p_k h = p_v h = p_o$, then the $h$ heads can be computed in parallel. In the implementation below, $p_o$ is specified by the parameter num_hiddens.
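
As a concrete example of this bookkeeping (using the same values as the test at the end of this post, $p_o = 100$ and $h = 5$; these specific numbers are only for illustration):

$$p_q = p_k = p_v = \frac{p_o}{h} = \frac{100}{5} = 20, \qquad \mathbf W_o \in \mathbb R^{p_o \times h p_v} = \mathbb R^{100 \times 100}.$$

Each head therefore works in a 20-dimensional space, the concatenation of the 5 heads is again 100-dimensional, and the per-head projections can be fused into a single linear layer of output size $5 \times 20 = 100$ for each of the queries, keys, and values, which is why one nn.Linear(..., num_hiddens) per input suffices in the code below.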

import math
import torch
from torch import nn
from d2l import torch as d2l
def transpose_qkv(X, num_heads):
    # Shape of input `X`: (`batch_size`, number of queries or key-value pairs, `num_hiddens`)
    # Shape of output `X`: (`batch_size`, number of queries or key-value pairs, `num_heads`, `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape of output `X`: (`batch_size`, `num_heads`, number of queries or key-value pairs, `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # Shape of `output`: (`batch_size` * `num_heads`, number of queries or key-value pairs, `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])
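
A quick shape check of transpose_qkv (an illustrative snippet that reuses the torch import above; the zeros tensor and the choice of 5 heads are arbitrary):

X_demo = torch.zeros(2, 4, 100)              # (batch_size, number of queries, num_hiddens)
transpose_qkv(X_demo, num_heads=5).shape     # torch.Size([10, 4, 20])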

def transpose_output(X, num_heads):
    """Reverse the operation of `transpose_qkv`"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
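
Since transpose_output is meant to invert transpose_qkv, a round trip should recover the original tensor exactly (again just an illustrative check with arbitrary shapes):

X_demo = torch.randn(2, 6, 100)
Y_demo = transpose_output(transpose_qkv(X_demo, 5), 5)
Y_demo.shape                    # torch.Size([2, 6, 100])
torch.equal(X_demo, Y_demo)     # True: the two transforms cancel out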

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # Each projection maps its input to an output of shape
        # (`batch_size`, number of queries or key-value pairs, `num_hiddens`)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
            # (`batch_size`, number of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`:
            # (`batch_size`,) or (`batch_size`, number of queries)
        # After transposing, shape of the output `queries`, `keys`, or `values`:
            # (`batch_size` * `num_heads`, number of queries or key-value pairs, `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        # Stacking the heads along the batch dimension lets all heads be computed in a single pass
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Repeat each entry of `valid_lens` `num_heads` times along dim 0 so every head gets a copy
            valid_lens = torch.repeat_interleave(valid_lens,
                                                repeats=self.num_heads,
                                                dim=0)
        output = self.attention(queries, keys, values, valid_lens)  # with the test below: output -> (10, 4, 20)
        output_concat = transpose_output(output, self.num_heads)    # with the test below: output_concat -> (2, 4, 100)
        return self.W_o(output_concat)

Let's test our MultiHeadAttention class. The shape of the multi-head attention output is (batch_size, num_queries, num_hiddens).

# The linear transformations have 100 output units, and we use 5 heads
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
# 2 batches, 4 queries, 6 key-value pairs
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens)) # queries: (2, 4, 100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) # keys and values: (2, 6, 100)
output = attention(X, Y, Y, valid_lens) # the output has the same shape as the input queries
output.shape
torch.Size([2, 4, 100])
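
To see where the intermediate shapes noted in the comments of forward come from (for example output -> (10, 4, 20)), we can apply the projections and transpose_qkv by hand; this sketch simply reuses the attention, X, Y, and num_heads objects defined above:

# The 5 heads are folded into the batch dimension, so each head sees
# num_hiddens / num_heads = 20 features per position.
q = transpose_qkv(attention.W_q(X), num_heads)   # torch.Size([10, 4, 20])
k = transpose_qkv(attention.W_k(Y), num_heads)   # torch.Size([10, 6, 20])
print(q.shape, k.shape)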

Summary

  • Multi-head attention fuses knowledge from the same attention pooling obtained via different representation subspaces of the queries, keys, and values.
  • Based on appropriate tensor operations, parallel computation of multi-head attention can be achieved.

Practice

  1. Visualize the attention weights of multiple heads in this experiment separately.
  2. Suppose we already have a trained multi-head attention-based model and now want to prune the least important attention heads to improve prediction speed. How should experiments be designed to measure the importance of attention heads?

Origin blog.csdn.net/jerry_liufeng/article/details/123054063