1. prepare_attention_mask
这里结合Mutil Head Attention了解下不同mask的作用。key_padding_mask
和attn_mask
两个实际上都是作用到attn_output_weights来影响最终的output,前者专注处理序列中的<PAD>
,而后者专注处理序列交叉中的“不可见
”逻辑。
首先先建立一个概念:多头的分头
,分的是QKV的 dim
维度:
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]
# 分头(split heads)
head_dim = dim // heads # heads=8
query=[batch_size, source_length, heads, dim//heads], key=[batch_size, target_len, heads, dim//heads], value=[batch_size, target_len, heads, dim//heads]
1.1 key_padding_mask
key_padding_mask
,长度是(B, S)
,B为batch_size
,S为源序列长度,即query的seq_len(NLP的token个数S/CV的patch个数HW)
,序列中没有到达max_len
的token用<PAD>
填充,key_padding_mask
中对应的位置为True
,计算attention时会将key中mask=True的部分省略掉。
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]
计算self-attention时,key_padding_mask
只屏蔽key中的mask:即非mask的token作为query时,和sequence中所有非mask的token作为key计算self-attention;而mask的token也可以作为query,和sequence中所有非mask的token作为key计算self-attention。(mask的token不作为key token参与计算
)因为就算mask的token作为key参与了计算,最后reshape会原来的形状后,也不使用padding的部分,所以这部分注意力的计算是冗余的。
torch实现方法,
key_padding_mask
是加到attn_mask
上进行实现的。如下图的伪代码实现是(不考虑多头时 attn_mask.shape=[batch, seq_len, seq_len]
):torch.baddbmm
计算QK然后将attn_mask
加到QK矩阵上,然后mask的部分就算负无穷-inf
,再经过softmax
就变为0
.
a t t e n t i o n = S o f t m a x ( Q K T d k + a t t n _ m a s k ) ⋅ V attention=Softmax(\frac{QK^T}{\sqrt{d_k}}+attn\_mask)·V attention=Softmax(dkQKT+attn_mask)⋅V
# 模拟key_pad_mask加到attn_mask上
import torch
from einops import rearrange, repeat
batch_size, seq_len, dim = 1, 9, 8
key_pad_mask = torch.tensor([False, False, True, False, False, True, False, False, True]).unsqueeze(0)
# tensor([[False, False, True, False, False, True, False, False, True]])
key_pad_mask = torch.where(key_pad_mask, float('-inf'), 0)
# tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf]])
key_pad_mask = repeat(key_pad_mask, 'b s -> b ss s', ss=seq_len)
'''
tensor([[[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
[0., 0., -inf, 0., 0., -inf, 0., 0., -inf]]])
'''
# 假设用casual attention: 下三角attn_mask
attn_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.where(attn_mask, float('-inf'), 0) # attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
'''
tensor([[[-inf, 0., 0., 0., 0., 0., 0., 0., 0.],
[-inf, -inf, 0., 0., 0., 0., 0., 0., 0.],
[-inf, -inf, -inf, 0., 0., 0., 0., 0., 0.],
[-inf, -inf, -inf, -inf, 0., 0., 0., 0., 0.],
[-inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0.],
[-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0.],
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0.],
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]])
'''
attn_mask += key_pad_mask
'''
tensor([[[-inf, 0., -inf, 0., 0., -inf, 0., 0., -inf],
[-inf, -inf, -inf, 0., 0., -inf, 0., 0., -inf],
[-inf, -inf, -inf, 0., 0., -inf, 0., 0., -inf],
[-inf, -inf, -inf, -inf, 0., -inf, 0., 0., -inf],
[-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
[-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., -inf],
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]])
'''
# query=[batch, seq_len, dim], key=[batch, tgt_len, dim], value=[batch, tgt_len, dim]
attn_score = torch.softmax(torch.baddbmm(attn_mask, query, key.transpose(-2, -1)), dim=-1)
attn_output = torch.bmm(attn_score, value)
1.2 attn_mask
attn_mask
,长度是(B, source_length, target_length)
,其中B表示batch_size
,source_length
表示源序列长度(Q的seq_len),target_length
是目标序列长度(KV的seq_len),表示对权重矩阵做mask;
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]
如果考虑多头,则要在scaled_dot_product_attention
之前,把attn_mask为每个head复制一份(diffusers中使用prepare_attention_mask
函数实现):
- 如果attn_mask的shape是4维度的,初始
(batch, source_length, target_length)
,则unseqeeze出一个head维度,沿第1维度(heads维度)复制heads
份,变成(batch, heads, source_length, target_length)
。 - 如果attn_mask的shape是3维度的,初始
(batch, source_length, target_length)
,直接将注意力掩码沿着第0维度(batch维度)重复head_size
次,变成(batch x heads, source_length, target_length)
。
这样共batch x heads
个头做 [source_length, target_length]@[target_length, source_length]
的矩阵乘法后,分别相同batch的head使用相同的attn_mask,然后再进行softmax。
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
2. AttnProcessor
用于执行 self-attention 或 cross-attention:
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states