论文标题:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
论文地址:https://arxiv.org/abs/2305.13245
【注】阿里的开源模型Qwen自Qwen2、Meta的开源模型LlaMa自LlaMa3,都使用了GQA来加速推理。GQA是一种可以加速模型推理但又不会明显降低推理质量的一种注意力机制。
我们都知道现在大语言模型都是基于Transformer的Block组成的,而Block的核心是注意力机制。对于注意力机制,在阅读本篇论文之前,有三个名词要加以区分:
- 多头注意力(Multi-Head Attention, MHA)
- 多查询注意力(Multi-Query Attention, MQA)
- 分组查询注意力(Grouped-Query Attention, GQA)
如图1所示,多头注意力(MHA)是传统的注意力机制,每个头都有独立的查询(Q)、键(K)、值(V)投影矩阵。多查询注意力(MQA)是所有头共享一组K、V投影矩阵,每个头只有独立的Q投影矩阵。分组查询注意力(GQA)将查询头分为G组,每组共享一组K、V投影矩阵。
早在2019年,MQA就已经被提出,但是MQA的推理速度虽然快,但是生成质量会明显下降。因此,论文提出了GQA,在推理速度和生成质量之间取得平衡。
【注】GQA论文:https://arxiv.org/pdf/1911.02150
Abstract
前人提出的多查询注意力(Multi-query attention, MQA)虽然可以加快推理速度,但是生成质量会明显下降。因此论文提出了分组查询注意力(grouped-query attention, GQA),该方法只需使用模型预训练时5%的计算资源,就能将现有的多头注意力(MHA)模型checkpoint转换为GQA模型。
实验表明,GQA在保持与MHA相当的生成质量的同时,能够达到与MQA相近的推理速度。
Introduction
自回归模型在推理时需要加载所有注意力层的键值缓存,这需要强大的性能和内存来保证。为了解决这个问题,研究人员提出了多查询注意力(MQA)机制来加速推理,但MQA存在生成质量下降和训练不稳定的缺陷。本文的主要贡献有两点:
-
证明了已有的多头注意力(multi-head attention, MHA)模型checkpoint可以通过简单的改造转换为MQA模型;
-
提出了分组查询注意力(GQA)机制,该方法可以在保持MHA的生成质量的同时,实现接近MQA的推理速度。
【注】自回归(Decoder Only)与自编码(Encoder Only)是两种不同的模型架构。自回归模型如GPT、LlaMa、Qwen等主要用于文本生成,而自编码模型如BERT则更适合语义理解任务。目前大语言模型领域主要以自回归架构为主,因为自编码模型虽然在语义提取方面有优势,但不太适合做生成任务。
Method
升杯!
这里论文主要介绍的是把用MHA训练的模型的checkpoint转换为GQA模型。
主要有2步:
1)转换checkpoint,把键、值的投影矩阵都通过平均池化转换为1个投影矩阵,如图2;
2)然后就是继续预训练。
GQA
如图1所示,分组查询注意力(GQA)采用了一种灵活的分组策略 - 将查询头(query heads)划分为G个组,每个组内共享同一组键头(key head)和值头(value head)。因此:
- GQA-G表示将查询头分为G组的分组查询注意力
- 当G=1时(即GQA-1),整个模型只有一组键值头,等同于MQA
- 当G等于头的总数时(即GQA-H),每个查询头都有独立的键值头,等同于MHA
值得一提的是,这种优化主要针对的是自回归(Decoder-only)模型。对于自编码(Encoder-only)模型而言,由于其可以并行计算注意力,因此注意力机制并非性能瓶颈。
说实话,平平无奇的trick。。。
Experiments
实验设置
论文所有实验都基于T5.1.1架构实现,适用了具有多头注意力的T5 Large和XXL,以及具有多查询和分组查询注意力的T5 XXL的升级训练版本。
【注】T5模型升级过。T5.1.1是T5的升级版本,T5.1.1的模型架构和T5.1.0是一样的,但是训练方法和数据集不一样。
数据集采用的是摘要数据集、翻译数据集、问答数据集。
【注】T5是编码器-解码器架构,可以看到作者挑选的数据集还是偏语义向的,没有用纯文本生成的数据集。
微调配置:学习率0.001,batch=128,dropout=0.1,训练直到收敛,然后选性能最好的checkpoint,推理用greedy decoding。
【注】感觉学习率设的好大啊。
实验结果
质量指标
性能指标
如图3与图4所示,GQA质量指标和MHA差不多,但是速度快了不少,达到MQA的水平。
消融
论文通过消融实验分析了以下几个关键变量对模型性能的影响:
1)checkpoint转换策略:对比了选取第一个头、平均所有头、随机选择头三种方式
2)继续预训练的数据量占比
3)查询头分组数量G的选择
通过这几个消融实验,实验证明了采用平均池化所有头、数据量占比为5%,分组数量为头的总数,可以得到最好的效果。
【注】还有要注意的一点是,目前解码器推理时大多实现了KV Cache,对降低显存占用的影响不会那么大。但是GQA仍然可以加快推理速度,这对于部署大模型来说是有意义的。