1. 模型结构
- 基本单元:token_embedding + positional encoding, encoder, token_embedding + positional encoding, decoder
- encoder: (self-attention, skip-connect, ln), (ffn, skip-connect, ln)
- decoder: (self-attention, skip-connect, ln), (cross-attention, skip-connect, ln), (ffn, skip-connect, ln)
2. 复杂度
3. 参数与计算量
- 反向传播优化过程:(1)前向计算损失函数,(2)后向计算梯度,(3)优化器更新参数
开始训练一个大模型之前,根据scaling law来估算,有多少数据,需要多少算力,要计算多少时间
- 深度学习每次前向计算,矩阵乘法就是一次加一次乘,一个parameters,要对应2次浮点计算,所以要乘以2
我们采用文献6中的约定:
- L: Transfomer 层树
- H:d_model, 也就是attention hidden_size维度
- h: 多头注意力有几个attention 头
- B: batchsize
- S:序列的长度,比如GPT 2K,LLama2 4K
- V: 词表里词的数量 vocab
推理显存
- 参数,KV-cache, 中间结果
- 中间结果占比不大: batch * token_length * embedding_size
Attention
从模型结构中拿出一个标准单元
attention, skip_connect, ln
+ ffn, skip_connect, ln
输入的embedding形状为: [B,S,H]
- 多头注意力先把Q, K, V都dense层到H维度,[B, S, H] X [H, H] = [B, S, H], 共计算BSH^2次 x 3
- 计算attention score, softmax(Q* K转置 / sqrt(d_model)),[B, h, S, H’] X [B, h, H’, S] = [B, h, S, S],考虑其中多头,共计算 BHS^2次
- 与V点积,[B, h, S, S] X [B, h, S, H’] = [B, h, S, H’],共计算 BhS^2 H’ = BHS**2次
- 经过dense线性层,多头转换回去,[B, h, S, H’] X [H’, H’] = [B, S, H],共计算BSH^2次
以上Attention过程总共计算为 2 * (3BSH2 + BSH2 + BHS2 + BHS2) = 8BSH**2 + 4BHS **2,乘以2是因为神经网络计算一次加法 一次乘法。
FFN
输入embedding形状为:[B, S, H]
- ffn第一层,[B, S, H] x [H, 4H] = [B, S, 4H], 共计算 4BSH^2
- ffn第二层,[B, S, 4H] x [4H, H] = [B, S, H], 共计算 4BSH^2
以上FFN总计算为 16 BSH**2
总计算量
前向计算量
- 一个attention + ffn单元:24BSH**2 + 4 BHS **2
- L层: L * (单元)
- 生成:2BSHV
反向求导的时候,Loss算梯度得到新weight然后更新,所以是前向计算的两倍,乘以4
完成每个参数,都过一遍所有Token的情况下,也就是一个epoch,要经过6次浮点运算
对于Llama 65B模型推导
- 模型参数: 65 * 10 ^9
- token: 1.4 * 10 ^12
因此需要算力 =6 * (模型参数 * 总token)
实际算力=GPU总数单个GPU算力单个GPU利用率。
实际算力 = 2048 * 312(A100 Tflops)* 10^12() * 0.45
需要算力/实际算力 = 时间(原文21天)
显存
-
装载模型,假如模型的参数是以FP16来计算的(A100之后BF16的居多,防止计算的时候上溢出)
-
一个参数被表示16位的浮点数,所以它也就占用2个byte 。
-
7B的话,静态显存占用量,指模型的所有参数被load到显存里,如果以BF16的话,要占据14个G
-
训练过程中,除了模型参数本身外,还有梯度和优化器
模型参数 + 梯度参数 + 优化器状态 + 激活 = 总显存
在一次用AdamW和混合精度训练的Epcho里,每一个模型参数,需要占用:
2byte的模型静态参数权重(以16bit存储)
2byte的模型更新参数权重(以16bit存储)
2byte的梯度(以16bit存储)
2byte的梯度更新(以16bit存储)
4byte的一阶动量优化器更新(以32bit存储)
4byte的二阶方差优化器更新(以32bit存储)
也就是: 一个模型参数需要占用16bytes的内存
更详细可以参考
4. Tokenizer
- Byte Pair Encoding(BPE), Byte-level BPE(BBPE),Uniform Language Model(ULM),WordPiece
- https://github.com/LongxingTan/Machine-learning-interview/blob/main/02_ml/11_nlp.md
5. Positional encoding/embedding
由于attention模型自身没有衡量位置的能力,因此需要位置编码。至于输入为什么是token_embedding + positional encoding, 可参考为什么 Bert 的三个 Embedding 可以进行相加? - 知乎
transformer论文中使用的是 positional encoding,
- 位置编码是和token embeding一样的形状,[B, S, H]
- 位置编码是位于[0, 1]的连续数值
旋转位置编码: RoPE
首先通过参考【4】了解Transformer中的位置编码,然后了解RoPE。最根本的变化,是在attention计算过程中,两两计算两个token的相似度期间,直接表征两个token的相对距离。
LLM学习记录(五)–超简单的RoPE理解方式 - suc16的文章 - 知乎
6. Attention 推理优化
KV-Cache
- KV cache主要分成5个方向的优化,即Sparse、Quantization、Allocator、Window、share
- 关于为什么Q不需要缓存,可参考为什么加速LLM推断有KV Cache而没有Q Cache? - 知乎
- KC cache计算量,显存分析,可参考KV cache详解 图示,显存,计算量分析,代码 - 莫笑傅立叶的文章 - 知乎
Multi-Query Attention (MQA)
- MQA 在 encoder 上的提速没有非常明显,但在 decoder 上的提速是很显著
Grouped Query Attention (GQA)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Attention
def __init__(self, ...):
self.num_key_value_heads = config.num_key_value_heads # Group 头树 ?
self.num_key_value_groups = self.num_heads // self.num_key_value_heads # Group, repeat次数
...
def forward(self, ...):
...
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
Sliding window attention (SWA)
Flash Attention
- FlashAttention主要解决Transformer计算速度慢和存储占用高的问题. 将优化重点放在了降低存储访问开销(Memory Access Cost,MAC)上
PagedAttention
Quantization
Decoding/sampling
Constrained sampling
Speculative decoding 投机采样
- Accelerating Large Language Model Decoding with Speculative Sampling
- Fast Inference from Transformers via Speculative Decoding
7. Decoder-only 推理
- How continuous batching enables 23x throughput in LLM inference while reducing p50 latency
- 关于 prefill cache, 可参考原理&图解vLLM Automatic Prefix Cache(RadixAttention): 首Token时延优化 - DefTruth的文章 - 知乎
预训练任务
- BERT MLM
# https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891
class MaskedLanguageModel(torch.nn.Module):
"""
predicting origin token from masked input sequence
n-class classification problem, n-class = vocab_size
"""
def __init__(self, hidden, vocab_size):
"""
:param hidden: output size of BERT model
:param vocab_size: total vocab size
"""
super().__init__()
self.linear = torch.nn.Linear(hidden, vocab_size)
self.softmax = torch.nn.LogSoftmax(dim=-1)
def forward(self, x):
return self.softmax(self.linear(x))
- GPT
8. 下游任务
文本分类/sentence embedding
last_hidden_state: [batch, maxlen, hidden_state]
pooling: cls pooling, [batch, 1, hidden_state]
class BertForSequenceClassification(BertPreTrainedModel):
NER
通过tokenizer 返回offset_mapping, 对应原始character的标注转化为token的label,进行分类任务
class BertForTokenClassification(BertPreTrainedModel):
QA
class BertForQuestionAnswering(BertPreTrainedModel):
参考
- https://jalammar.github.io/illustrated-transformer/
- http://nlp.seas.harvard.edu/annotated-transformer/
- https://github.com/Kyubyong/transformer
- https://kipp.ly/transformer-taxonomy/
- Transformer学习笔记一:Positional Encoding(位置编码) - 猛猿的文章 - 知乎
- 浅谈后向传递的计算量大约是前向传递的两倍 - 回旋托马斯x的文章 - 知乎
- LLM 参数,显存,Tflops? 训练篇(1) - 周博洋的文章 - 知乎
- Llama源码深入解析 - 一个有毅力的吃货的文章 - 知乎
- 十分钟读懂旋转编码(RoPE) - 绝密伏击的文章 - 知乎
- 大模型推理性能优化之KV Cache解读 - Young的文章 - 知乎
- Muti Query Attention 和 Attention with Linear Bias(附源码) - 何枝的文章 - 知乎
- DistServe速读——Prefill & Decode解耦、模型并行策略&GPU资源分配解耦 - 阿杰的文章 - 知乎
- https://github.com/alibaba/Megatron-LLaMA
- 稀疏注意力计算:sliding window attention - Linsight的文章 - 知乎
- 大模型推理加速:KV Cache Sparsity(稀疏化)方法 - 歪门正道的文章 - 知乎