大模型 | 一篇搞明白上下文长度扩展:从RoPE到YARN

目前的较为流行的支持长序列的模型比如Qwen2.5、DeepSeek-R1等模型都在训练中引入了YARN[1]来做上下文扩展。本文主要就是按照原论文的结构对YARN进行介绍,希望本文能够帮助读者更好的理解旋转位置编码以及YARN的原理。

YARN的全称是Yet Another RoPE Extention,顾名思义,YARN是对RoPE的一种扩展,应用YARN后只需在少量的长文本数据上微调即可实现模型上下文长度的扩展。

了解YARN的前提是首先得了解RoPE是怎么做到将位置信息注入LLM中的。

一、RoPE

1.1 RoPE简介

img
img

1.2 RoPE的优点

绝对位置:RoPE本身的形式向输入向量添加了绝对位置信息。

相对位置:应用RoPE后的向量在进行内积运算时相当于考虑了输入向量的相对位置信息。考虑位置 m 处的向量 q 与位置n处的向量 k 二者应用RoPE后的内积为

img

  1. 远程衰减:给定向量维度 D 特定 b bb 的取值能够使得应用RoPE后的向量内积具有远程衰减性质,即:对于两个固定的向量,他们之间的相对距离越远,内积值越小。

  2. 形式简单,与当前attention机制的结合非常自然。

1.3 RoPE的参数化

img

1.4 RoPE中的隐含的高低频分量的概念

img

二、上下文长度拓展

上面我们简单的介绍了一下RoPE,下面让我们来考虑使用RoPE训练完的模型如何进行上下文长度拓展。本节中,我们将介绍对RoPE的一些改进方案,并逐步过度到YARN。

img

2.1 内插与外推

我们考虑一下一维的情况。假设原先我们的模型支持四个位置,分别用[0,4)区间中的0、1、2、3来表示,现在我们希望模型能够支持8个不同的位置,我们应该怎么办呢?有两种直观的方法可以做到这一点:

  1. 保持相邻点的间隔为1不变,将取值范围从[0,4)直接将取值范围扩展至[0,8)即可,这就是所谓的外推(extrapolation)
  2. 维持原先的区间不变,从原区间取更多的点来表示新的位置,此时我们的取值范围维持[0,4),但相邻点之间的间隔从1缩小到了0.5,这就是所谓的内插(interpolation)

考虑完1维的简单情形,下面让我们由易到难看看几种不同的对RoPE进行改动以进行上下文长度拓展方法,每种方法都是对前一种方案的改进,最终我们将会得到YARN。

2.2 Position Interpolation (PI)

2.2.1 方法介绍

img

2.2.2 存在的问题

PI的思想非常直观,就是位置编号上的完全内插,实现起来非常简单。但PI存在一定的问题。 PI的旋转角度计算公式可以重写为
img

Deep neural networks have trouble learning high frequency information if the input dimension is low and the corresponding embeddings lack high frequency components.

Stretching the RoPE embeddings indiscriminately results in the loss of important high frequency details which the network needs in order to resolve tokens that are both very similar and very close together

翻译一下就是:根据NTK理论,当模型的输入特征的维度很低而对应的embedding又缺乏高频分量时(这恰好是位置编码所面临的情况,位置信息是一维的,而我们在将他变成高维的embedding信息注入到模型中),模型很难学到高频信息[2]。回忆一下我们在1.4节中关于频率分量的讨论,PI的做法相当于把所有的分量的频率都统一降低成了原先的 L/L′/L’,这样会导致模型丢失原先高频分量中的细节信息,使得模型难以区分相对位置接近而本身语义又相似的token。

简而言之,PI存在的问题是:根据NTK理论,输入特征中高频分量的分布对模型十分重要,而PI的做法导致输入中的高频分量的分布发生了较大的变化,对模型的性能有损害

2.3 NTK-aware Interpolation

针对上述对于PI的缺陷分析,研究人员提出了NTK-aware Interpolation。

2.3.1 方法介绍

既然NTK理论告诉我们,模型对高频分量的分布敏感,那么我们应该尽量保持高频分量的分布不变,而在低频分量的部分做插值,也就是高频外推,低频内插

img
在这里插入图片描述

【多种曲线均符合“高频外推,低频内插的原则”】

img

2.3.2 存在的问题

NTK-aware看上去很优美,考虑到了频率与内外插程度应当是相关联的,并用一个拟合出的指数函数来将分组dd(也就是频率)与内外插的程度联系起来,但这样的建模足够精细吗?答案是否定的。

YARN的作者意识到,在RoPE的训练过程中存在一些足够低频的分量,这些低频分量对应的波长λd λd 长到即使是训练过程中最长的序列也没有办法让这些分量经过一个完整周期,对于这些分量,我们显然不应该对他们进行任何的外推。否则可能会引入一些从未见过的旋转角度,这些旋转角度对应的正余弦值在训练过程中模型也从未见过,会导致模型的效果下降。

在这里插入图片描述
【对于足够低频的分量,外推会引入分布外的旋转矩阵】

让我们从数学上推导一下哪些分量出现了过度的外推,不感兴趣的可以直接跳过,看下一节:

1. 哪些分量出现了外推
所谓外推就是值域出现了扩大,即,最大旋转角度在扩展后超过了原先的最大旋转角度,即

img

代入具体的函数形式,我们可以得到

img

2. 哪些分量不应该外推
如前所述,对于波长大于原最大序列长度的那些低频分量,我们不应该对他们进行外推,即:

img

2.4 NTK-by-parts Interpolation

在上一节中我们分析了NTK-aware插值方法存在的问题,在某些极端低频的分量上进行了过度的外推,导致模型性能下降。

同理,我们可以考虑在哪些分量我们或许可以做完全的外推。根据NTK理论,模型对高频分量的分布敏感,因此对足够高频的分量应该尽量保持其频率不变,需要完全的外推。NTK-by-parts就是基于这样的思想提出的,对于足够低频的分量做完全的内插,对足够高频的分量做完全的外推,而对中间部分的分量,既外推也内插。

img

2.5 YARN

终于,在讲了3种层层递进的对RoPE进行长度拓展的方法后来到了我们的终极方法YARN。

如果前面几种方法你基本都理解了,那么看到这里的你可以松一口气了。因为YARN的本体就是NTK-by-parts,只是YARN在NTK-by-parts的基础上额外增加了一个attention-scaling的机制。用原文作者给出的直观表示:

YARN = NTK-by-parts + attention-scaling

所谓的attention-scaling就是在计算attention的环节,YARN会额外对attention score(也就是query和key向量的内积)除以一个常数td。形式上,了解对比学习或者LLM生成过程的同学应该挺熟悉的,这相当于对attention的计算过程加了个温度系数。

img

三、Coding环节

Talk is cheap, show me the code.

讲了这么多理论,还是得动手实践一下来检验自己的理解。目前使用YARN的知名度最高的开源模型当属DeepSeek-R1了,你能完成以下的编程练习来实现一个YARN使得你的实现与DeepSeek-R1的结果一致吗? 有以下两点需要注意:

DeepSeek-R1中YARN的实现在中间频率分量部分的插值处理与YARN的原文不大一样,需要你读一下DeepSeek-R1的源码
YARN的计算并不涉及到模型权重,因此你并不需要clone完整的权重,只需用以下命令clone不带权重的部分即可

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-R1

下面是题目,你需要补全DeepseekV3YarnRotaryEmbeddingReproduce类中关于cos与sin cache的计算逻辑

import math
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoConfig
class DeepseekV3YarnRotaryEmbeddingReproduce(nn.Module):
    def __init__(self, config):
        # Parameter names are consistent with the formula in this post. Change them if you want.
        # RoPE params
        self.b = config.rope_theta
        self.L = config.rope_scaling["original_max_position_embeddings"]
        self.D = config.hidden_size
        # YARN Params
        self.s = config.rope_scaling["factor"]
        self.L_sharp = self.L * self.s
        self.alpha = config.rope_scaling["beta_slow"]
        self.beta = config.rope_scaling["beta_fast"]
        # Cos & Sin Cache calculation
        cos_cached = torch.zeros([self.L_sharp, self.D], dtype=torch.float32)
        sin_cached = torch.zeros([self.L_sharp, self.D], dtype=torch.float32)
        # TODO: add your codes here
        self.register_buffer("cos_cached", cos_cache)
        self.register_buffer("sin_cached", sin_cache)
    def forward(self, x=None, seq_len=None):
        cos = self.cos_cache[:seq_len]
        sin = self.sin_cache[:seq_len]
        return cos, sin
if __name__ == "__main__":
    config = AutoConfig.from_pretrained(
        "YOUR_DeepSeek_R1_path",  # 替换成本地的DeepSeek-R1代码路径
        trust_remote_code=True,
        first_k_dense_replace=1,
        num_hidden_layers=1
    )
    ds_v3_yarn = AutoModelForCausalLM.from_config(
        config,
        trust_remote_code=True
    ).model.layers[0].self_attn.rotary_emb
    yarn_rep = DeepseekV3YarnRotaryEmbeddingReproduce(config)
    cos_ref, sin_ref = ds_v3_yarn.cos_cached, ds_v3_yarn.sin_cached
    cos_rep, sin_rep = yarn_rep.cos_cached, yarn_rep.sin_cached
    assert torch.allclose(cos_ref, cos_rep, atol=1e-6), f"Cosine not equal."
    assert torch.allclose(sin_ref, sin_rep, atol=1e-6), f"Sine not equal."
    print("Congratulations!")

四、如何系统学习掌握AI大模型?

AI大模型作为人工智能领域的重要技术突破,正成为推动各行各业创新和转型的关键力量。抓住AI大模型的风口,掌握AI大模型的知识和技能将变得越来越重要。

学习AI大模型是一个系统的过程,需要从基础开始,逐步深入到更高级的技术。

这里给大家精心整理了一份全面的AI大模型学习资源,包括:AI大模型全套学习路线图(从入门到实战)、精品AI大模型学习书籍手册、视频教程、实战学习、面试题等,资料免费分享

在这里插入图片描述

1. 成长路线图&学习规划

要学习一门新的技术,作为新手一定要先学习成长路线图方向不对,努力白费

这里,我们为新手和想要进一步提升的专业人士准备了一份详细的学习成长路线图和规划。可以说是最科学最系统的学习成长路线。

在这里插入图片描述

2. 大模型经典PDF书籍

书籍和学习文档资料是学习大模型过程中必不可少的,我们精选了一系列深入探讨大模型技术的书籍和学习文档,它们由领域内的顶尖专家撰写,内容全面、深入、详尽,为你学习大模型提供坚实的理论基础(书籍含电子版PDF)

在这里插入图片描述

3. 大模型视频教程

对于很多自学或者没有基础的同学来说,书籍这些纯文字类的学习教材会觉得比较晦涩难以理解,因此,我们提供了丰富的大模型视频教程,以动态、形象的方式展示技术概念,帮助你更快、更轻松地掌握核心知识

在这里插入图片描述

4. 2024行业报告

行业分析主要包括对不同行业的现状、趋势、问题、机会等进行系统地调研和评估,以了解哪些行业更适合引入大模型的技术和应用,以及在哪些方面可以发挥大模型的优势。

在这里插入图片描述

5. 大模型项目实战

学以致用 ,当你的理论知识积累到一定程度,就需要通过项目实战,在实际操作中检验和巩固你所学到的知识,同时为你找工作和职业发展打下坚实的基础。

在这里插入图片描述

6. 大模型面试题

面试不仅是技术的较量,更需要充分的准备。

在你已经掌握了大模型技术之后,就需要开始准备面试,我们将提供精心整理的大模型面试题库,涵盖当前面试中可能遇到的各种技术问题,让你在面试中游刃有余。

在这里插入图片描述

全套的AI大模型学习资源已经整理打包,有需要的小伙伴可以微信扫描下方CSDN官方认证二维码,免费领取【保证100%免费

在这里插入图片描述