Transformer模型特辑

1. 模型结构

transformer 原文插图

  • 基本单元:token_embedding + positional encoding, encoder, token_embedding + positional encoding, decoder
  • encoder: (self-attention, skip-connect, ln), (ffn, skip-connect, ln)
  • decoder: (self-attention, skip-connect, ln), (cross-attention, skip-connect, ln), (ffn, skip-connect, ln)

2. 复杂度

3. 参数与计算量

参考文献5插图

  • 反向传播优化过程:(1)前向计算损失函数,(2)后向计算梯度,(3)优化器更新参数

开始训练一个大模型之前,根据scaling law来估算,有多少数据,需要多少算力,要计算多少时间

  • 深度学习每次前向计算,矩阵乘法就是一次加一次乘,一个parameters,要对应2次浮点计算,所以要乘以2

我们采用文献6中的约定:

  • L: Transfomer 层树
  • H:d_model, 也就是attention hidden_size维度
  • h: 多头注意力有几个attention 头
  • B: batchsize
  • S:序列的长度,比如GPT 2K,LLama2 4K
  • V: 词表里词的数量 vocab

推理显存

  • 参数,KV-cache, 中间结果
  • 中间结果占比不大: batch * token_length * embedding_size

Attention

从模型结构中拿出一个标准单元
attention, skip_connect, ln + ffn, skip_connect, ln

在这里插入图片描述

输入的embedding形状为: [B,S,H]

  • 多头注意力先把Q, K, V都dense层到H维度,[B, S, H] X [H, H] = [B, S, H], 共计算BSH^2次 x 3
  • 计算attention score, softmax(Q* K转置 / sqrt(d_model)),[B, h, S, H’] X [B, h, H’, S] = [B, h, S, S],考虑其中多头,共计算 BHS^2次
  • 与V点积,[B, h, S, S] X [B, h, S, H’] = [B, h, S, H’],共计算 BhS^2 H’ = BHS**2次
  • 经过dense线性层,多头转换回去,[B, h, S, H’] X [H’, H’] = [B, S, H],共计算BSH^2次

以上Attention过程总共计算为 2 * (3BSH2 + BSH2 + BHS2 + BHS2) = 8BSH**2 + 4BHS **2,乘以2是因为神经网络计算一次加法 一次乘法。

FFN

输入embedding形状为:[B, S, H]

  • ffn第一层,[B, S, H] x [H, 4H] = [B, S, 4H], 共计算 4BSH^2
  • ffn第二层,[B, S, 4H] x [4H, H] = [B, S, H], 共计算 4BSH^2

以上FFN总计算为 16 BSH**2

总计算量

前向计算量

  • 一个attention + ffn单元:24BSH**2 + 4 BHS **2
  • L层: L * (单元)
  • 生成:2BSHV

反向求导的时候,Loss算梯度得到新weight然后更新,所以是前向计算的两倍,乘以4

完成每个参数,都过一遍所有Token的情况下,也就是一个epoch,要经过6次浮点运算

对于Llama 65B模型推导

  • 模型参数: 65 * 10 ^9
  • token: 1.4 * 10 ^12

因此需要算力 =6 * (模型参数 * 总token)

实际算力=GPU总数单个GPU算力单个GPU利用率。
实际算力 = 2048 * 312(A100 Tflops)* 10^12() * 0.45

需要算力/实际算力 = 时间(原文21天)

显存

  • 装载模型,假如模型的参数是以FP16来计算的(A100之后BF16的居多,防止计算的时候上溢出)

  • 一个参数被表示16位的浮点数,所以它也就占用2个byte 。

  • 7B的话,静态显存占用量,指模型的所有参数被load到显存里,如果以BF16的话,要占据14个G

  • 训练过程中,除了模型参数本身外,还有梯度和优化器

    • 模型参数 + 梯度参数 + 优化器状态 + 激活 = 总显存

在一次用AdamW和混合精度训练的Epcho里,每一个模型参数,需要占用:
2byte的模型静态参数权重(以16bit存储)
2byte的模型更新参数权重(以16bit存储)
2byte的梯度(以16bit存储)
2byte的梯度更新(以16bit存储)
4byte的一阶动量优化器更新(以32bit存储)
4byte的二阶方差优化器更新(以32bit存储)

也就是: 一个模型参数需要占用16bytes的内存

更详细可以参考

4. Tokenizer

  • Byte Pair Encoding(BPE), Byte-level BPE(BBPE),Uniform Language Model(ULM),WordPiece
  • https://github.com/LongxingTan/Machine-learning-interview/blob/main/02_ml/11_nlp.md

5. Positional encoding/embedding

由于attention模型自身没有衡量位置的能力,因此需要位置编码。至于输入为什么是token_embedding + positional encoding, 可参考为什么 Bert 的三个 Embedding 可以进行相加? - 知乎

transformer论文中使用的是 positional encoding,

  • 位置编码是和token embeding一样的形状,[B, S, H]
  • 位置编码是位于[0, 1]的连续数值

旋转位置编码: RoPE

首先通过参考【4】了解Transformer中的位置编码,然后了解RoPE。最根本的变化,是在attention计算过程中,两两计算两个token的相似度期间,直接表征两个token的相对距离。
请添加图片描述
请添加图片描述

LLM学习记录(五)–超简单的RoPE理解方式 - suc16的文章 - 知乎

6. Attention 推理优化

KV-Cache

Multi-Query Attention (MQA)

  • MQA 在 encoder 上的提速没有非常明显,但在 decoder 上的提速是很显著

在这里插入图片描述

Grouped Query Attention (GQA)

在这里插入图片描述

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Attention
	def __init__(self, ...):
		self.num_key_value_heads = config.num_key_value_heads  # Group 头树 ?
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # Group, repeat次数
		...
	
	def forward(self, ...):
		...
		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
    
    "sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

Sliding window attention (SWA)

在这里插入图片描述

Flash Attention

  • FlashAttention主要解决Transformer计算速度慢和存储占用高的问题. 将优化重点放在了降低存储访问开销(Memory Access Cost,MAC)上

PagedAttention

在这里插入图片描述

Quantization

Decoding/sampling

Constrained sampling

Speculative decoding 投机采样

  • Accelerating Large Language Model Decoding with Speculative Sampling
  • Fast Inference from Transformers via Speculative Decoding

7. Decoder-only 推理

预训练任务

  • BERT MLM
# https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891
class MaskedLanguageModel(torch.nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))
  • GPT

8. 下游任务

文本分类/sentence embedding
在这里插入图片描述
last_hidden_state: [batch, maxlen, hidden_state]
pooling: cls pooling, [batch, 1, hidden_state]

class BertForSequenceClassification(BertPreTrainedModel):

NER
通过tokenizer 返回offset_mapping, 对应原始character的标注转化为token的label,进行分类任务
在这里插入图片描述

class BertForTokenClassification(BertPreTrainedModel):

QA

class BertForQuestionAnswering(BertPreTrainedModel):

参考