torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
功能:创建一个多头注意力模块,参考论文《transformer》,参考论文及源码笔记:https://blog.csdn.net/qq_50001789/article/details/132181971
多头注意力公式为:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O MultiHead(Q,K,V)=Concat(head_1,\dots,head_h)W^O MultiHead(Q,K,V)=Concat(head1,…,headh)WO
其中 h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) head_i=Attention(QW^Q_i,KW_i^K,VW^V_i) headi=Attention(QWiQ,KWiK,VWiV),流程图如下:
参数:
-
embed_dim
:输入数据的维度,也就是向量的长度; -
num_heads
:表示并行注意力的数量,也就是“头”的数量; -
dropout
:表示注意力权重的丢弃概率,相当于生成注意力之后,再将注意力传入一层Dropout层,默认为0; -
bias
:在做线性变换Linear时,是否添加偏置,默认True
-
add_bias_kv
:kv
做线性变换时是否加偏置,若键值维度与嵌入维度相同,则可以将add_bias_kv
设为False,默认False
-
add_zero_attn
:将一个全零注意力向量添加到最终的输出中(只影响形状,不改变数值),强制使输出张量的形状与输入张量相同 -
kdim
:keys的特征数据维度,即向量长度,默认与embed_dim
相等 -
vdim
:values的特征数据维度,即向量长度,默认与embed_dim
相等 -
batch_first
:如果设为True
,则输入、输出张量表示为(batch, seq, feature),否则张量表示为(seq, batch, feature),默认False
。
注意:
embed_dim
会被划分成num_heads
份,对应的数据也会被划分,传入不同的“head”里,每个“head”的维度是embed_dim // num_heads
;
前向传播
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)
参数:
-
query, key, value
:表示传入的qkv数据,形式因batch_first
变量而异,默认(seq, batch, feature),即(序列,batch,特征); -
key_padding_mask
:用于指定哪些位置是填充位置,以便在计算注意力权重时将其忽略。对于batch
数据,输入尺寸应为 ( N , S ) (N,S) (N,S),其中 S S S为序列长度,对于非batch
数据,输入尺寸应为 S S S,里面的数值可以是布尔、也可以是浮点数。常用布尔数据,True
表示该位置为填充,计算注意力的时候需要忽略该位置,如果传入浮点数,则会将该数与key相加,常加负数,用于抑制该位置(False与负无穷效果一样); -
need_weights
:如果指定为True
,则网络会额外输出注意力权重; -
attn_mask
:尺寸为 ( L , S ) (L,S) (L,S)或 ( N , n u m h e a d s , L , S ) (N,num_heads,L,S) (N,numheads,L,S),其中 L L L表示目标序列长度,数值表示位置, S S S表示源序列长度,数值表示位置,如果 attn_mask[b, :, i, j] 为 True,则表示第 b 个样本、第 i 个目标位置和第 j 个源位置之间需要进行注意力计算; -
average_attn_weights
:表示是否要对多头注意力中的权重沿“头”方向做平均,将多组注意力矩阵生成一组矩阵,设为True
时,表示需要做平均,即生成一个注意力矩阵,默认True,即生成每个头的注意力矩阵。只有当need_weights
设置为True时,该参数才有意义; -
is_causal
:如果 is_causal 为 True,表明目标序列中的每个位置只能依赖于它之前的位置,这个操作能够实现因果性,默认False。这个参数只作为一个提示,最终是否是因果的,还是要看参数attn_mask。
注:
-
权重由计算
k、q
的相似度得到,得到的权重再与v相乘,做加权求和; -
计算过程:
先让kqv做线性映射,之后沿特征向量的方向拆分成不同的“头”,之后利用拆分的向量做运算→q和k做矩阵乘法,得到注意力权重→注意力权重除以缩放因子 d k \sqrt{d_k} dk, d k d_k dk表示每个头的维度,再做Softmax运算→经过一次Dropout运算(可选)→所得的权重与v做矩阵乘法→合并所有“头”,最后经过一次线性映射;
- 多头是拆特征,不是拆序列;
多头注意力K、Q、V解释:
- 目前有多组键值匹配对k、v,每个k对应一个v,计算q所对应的值。思路:计算q与每个k的相似度,得到v的权重,之后对v做加权求和,得到q对应的数值。因此在解码过程中,第二个多头注意力的输入中,k、v传入编码特征(是已知的特征匹配对),q传入解码特征(可迭代传入),求解码对应的特征(根据编码特征之间的相似度求解码的注意力加权特征)。
注:kqv的关系用一句话来说就是根据kv的键值匹配关系,预测q对应的数值,根据kq的相似度对v做加权求和。
实现方法
代码来源:https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
官方文档
nn.MultiheadAttention:https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=attention#torch.nn.MultiheadAttention