首先通过一个具体的数字例子,详细说明如何使用旋转位置编码(Rotary Position Embedding, RoPE)。
1. 输入数据
假设:
- 输入张量
x
:形状为(batch_size=2, seq_len=3, num_heads=1, head_dim=4)
,值为:
x = [ [ [ 1 , 2 , 3 , 4 ] ] , [ [ 5 , 6 , 7 , 8 ] ] , [ [ 9 , 10 , 11 , 12 ] ] ] x = \begin{bmatrix} [[1, 2, 3, 4]], \\ [[5, 6, 7, 8]], \\ [[9, 10, 11, 12]] \end{bmatrix} x= [[1,2,3,4]],[[5,6,7,8]],[[9,10,11,12]] - 位置索引
t
:序列长度为3
,位置索引为[0, 1, 2]
。 - 旋转频率
theta
:设为10000.0
。
2. 预计算旋转编码 pos_cis
(1) 计算频率向量
- 向量维度
dim = 4
,因此dim // 2 = 2
。 - 频率向量公式:
θ j = 1 base j / d \theta_j = \frac{1}{\text{base}^{j / d}} θj=basej/d1
其中 j = 0 , 1 j = 0, 1 j=0,1, d = 4 d = 4 d=4,base = 10000.0
。
计算:
θ 0 = 1 1000 0 0 / 4 = 1.0 θ 1 = 1 1000 0 1 / 4 ≈ 0.5623 \theta_0 = \frac{1}{10000^{0 / 4}} = 1.0 \\ \theta_1 = \frac{1}{10000^{1 / 4}} \approx 0.5623 θ0=100000/41=1.0θ1=100001/41≈0.5623
因此,频率向量为:
freqs = [ 1.0 , 0.5623 ] \text{freqs} = [1.0, 0.5623] freqs=[1.0,0.5623]
(2) 计算旋转角度
- 位置索引
t = [0, 1, 2]
。 - 旋转角度公式:
m θ j = t ⋅ θ j m \theta_j = t \cdot \theta_j mθj=t⋅θj
计算:
t ⋅ freqs = [ 0 ⋅ 1.0 0 ⋅ 0.5623 1 ⋅ 1.0 1 ⋅ 0.5623 2 ⋅ 1.0 2 ⋅ 0.5623 ] = [ 0.0 0.0 1.0 0.5623 2.0 1.1246 ] t \cdot \text{freqs} = \begin{bmatrix} 0 \cdot 1.0 & 0 \cdot 0.5623 \\ 1 \cdot 1.0 & 1 \cdot 0.5623 \\ 2 \cdot 1.0 & 2 \cdot 0.5623 \end{bmatrix} = \begin{bmatrix} 0.0 & 0.0 \\ 1.0 & 0.5623 \\ 2.0 & 1.1246 \end{bmatrix} t⋅freqs= 0⋅1.01⋅1.02⋅1.00⋅0.56231⋅0.56232⋅0.5623 = 0.01.02.00.00.56231.1246
(3) 生成复数形式的旋转编码
使用 torch.polar
生成旋转编码:
pos_cis = e i ⋅ m θ j \text{pos\_cis} = e^{i \cdot m \theta_j} pos_cis=ei⋅mθj
计算:
pos_cis = [ e i ⋅ 0.0 e i ⋅ 0.0 e i ⋅ 1.0 e i ⋅ 0.5623 e i ⋅ 2.0 e i ⋅ 1.1246 ] = [ 1.0 + i ⋅ 0.0 1.0 + i ⋅ 0.0 0.5403 + i ⋅ 0.8415 0.8472 + i ⋅ 0.5314 − 0.4161 + i ⋅ 0.9093 0.4284 + i ⋅ 0.9036 ] \text{pos\_cis} = \begin{bmatrix} e^{i \cdot 0.0} & e^{i \cdot 0.0} \\ e^{i \cdot 1.0} & e^{i \cdot 0.5623} \\ e^{i \cdot 2.0} & e^{i \cdot 1.1246} \end{bmatrix} = \begin{bmatrix} 1.0 + i \cdot 0.0 & 1.0 + i \cdot 0.0 \\ 0.5403 + i \cdot 0.8415 & 0.8472 + i \cdot 0.5314 \\ -0.4161 + i \cdot 0.9093 & 0.4284 + i \cdot 0.9036 \end{bmatrix} pos_cis=
ei⋅0.0ei⋅1.0ei⋅2.0ei⋅0.0ei⋅0.5623ei⋅1.1246
=
1.0+i⋅0.00.5403+i⋅0.8415−0.4161+i⋅0.90931.0+i⋅0.00.8472+i⋅0.53140.4284+i⋅0.9036
3. 应用旋转位置编码
(1) 将输入张量 x
转换为复数形式
x
的形状为(2, 3, 1, 4)
,将其转换为复数形式:
x q = [ [ 1 + i ⋅ 2 , 3 + i ⋅ 4 ] , [ 5 + i ⋅ 6 , 7 + i ⋅ 8 ] , [ 9 + i ⋅ 10 , 11 + i ⋅ 12 ] ] x_q = \begin{bmatrix} [1 + i \cdot 2, 3 + i \cdot 4], \\ [5 + i \cdot 6, 7 + i \cdot 8], \\ [9 + i \cdot 10, 11 + i \cdot 12] \end{bmatrix} xq= [1+i⋅2,3+i⋅4],[5+i⋅6,7+i⋅8],[9+i⋅10,11+i⋅12]
(2) 调整 pos_cis
的形状
pos_cis
的形状为(3, 2)
,调整为(1, 3, 1, 2)
,以便与x
广播。
(3) 应用旋转编码
将 pos_cis
与 x_q
逐元素相乘:
x q ′ = x q ⋅ pos_cis x_q' = x_q \cdot \text{pos\_cis} xq′=xq⋅pos_cis
计算:
x q ′ = [ ( 1 + i ⋅ 2 ) ⋅ ( 1.0 + i ⋅ 0.0 ) ( 3 + i ⋅ 4 ) ⋅ ( 1.0 + i ⋅ 0.0 ) ( 5 + i ⋅ 6 ) ⋅ ( 0.5403 + i ⋅ 0.8415 ) ( 7 + i ⋅ 8 ) ⋅ ( 0.8472 + i ⋅ 0.5314 ) ( 9 + i ⋅ 10 ) ⋅ ( − 0.4161 + i ⋅ 0.9093 ) ( 11 + i ⋅ 12 ) ⋅ ( 0.4284 + i ⋅ 0.9036 ) ] x_q' = \begin{bmatrix} (1 + i \cdot 2) \cdot (1.0 + i \cdot 0.0) & (3 + i \cdot 4) \cdot (1.0 + i \cdot 0.0) \\ (5 + i \cdot 6) \cdot (0.5403 + i \cdot 0.8415) & (7 + i \cdot 8) \cdot (0.8472 + i \cdot 0.5314) \\ (9 + i \cdot 10) \cdot (-0.4161 + i \cdot 0.9093) & (11 + i \cdot 12) \cdot (0.4284 + i \cdot 0.9036) \end{bmatrix} xq′=
(1+i⋅2)⋅(1.0+i⋅0.0)(5+i⋅6)⋅(0.5403+i⋅0.8415)(9+i⋅10)⋅(−0.4161+i⋅0.9093)(3+i⋅4)⋅(1.0+i⋅0.0)(7+i⋅8)⋅(0.8472+i⋅0.5314)(11+i⋅12)⋅(0.4284+i⋅0.9036)
逐元素计算结果:
x q ′ = [ 1.0 + i ⋅ 2.0 3.0 + i ⋅ 4.0 − 2.6248 + i ⋅ 7.3479 − 0.3776 + i ⋅ 10.1706 − 13.5123 + i ⋅ 2.0013 − 6.5604 + i ⋅ 15.1404 ] x_q' = \begin{bmatrix} 1.0 + i \cdot 2.0 & 3.0 + i \cdot 4.0 \\ -2.6248 + i \cdot 7.3479 & -0.3776 + i \cdot 10.1706 \\ -13.5123 + i \cdot 2.0013 & -6.5604 + i \cdot 15.1404 \end{bmatrix} xq′=
1.0+i⋅2.0−2.6248+i⋅7.3479−13.5123+i⋅2.00133.0+i⋅4.0−0.3776+i⋅10.1706−6.5604+i⋅15.1404
(4) 将复数形式转换回实数形式
将 x_q'
转换回实数形式,形状为 (2, 3, 1, 4)
:
x q ′ = [ [ [ 1.0 , 2.0 , 3.0 , 4.0 ] ] , [ [ − 2.6248 , 7.3479 , − 0.3776 , 10.1706 ] ] , [ [ − 13.5123 , 2.0013 , − 6.5604 , 15.1404 ] ] ] x_q' = \begin{bmatrix} [[1.0, 2.0, 3.0, 4.0]], \\ [[-2.6248, 7.3479, -0.3776, 10.1706]], \\ [[-13.5123, 2.0013, -6.5604, 15.1404]] \end{bmatrix} xq′=
[[1.0,2.0,3.0,4.0]],[[−2.6248,7.3479,−0.3776,10.1706]],[[−13.5123,2.0013,−6.5604,15.1404]]
4. 最终结果
应用旋转位置编码后,x
的值变为:
x ′ = [ [ [ 1.0 , 2.0 , 3.0 , 4.0 ] ] , [ [ − 2.6248 , 7.3479 , − 0.3776 , 10.1706 ] ] , [ [ − 13.5123 , 2.0013 , − 6.5604 , 15.1404 ] ] ] x' = \begin{bmatrix} [[1.0, 2.0, 3.0, 4.0]], \\ [[-2.6248, 7.3479, -0.3776, 10.1706]], \\ [[-13.5123, 2.0013, -6.5604, 15.1404]] \end{bmatrix} x′=
[[1.0,2.0,3.0,4.0]],[[−2.6248,7.3479,−0.3776,10.1706]],[[−13.5123,2.0013,−6.5604,15.1404]]
5. 总结
通过这个具体的数字例子,我们展示了旋转位置编码的完整计算过程:
- 预计算旋转编码
pos_cis
。 - 将输入张量
x
转换为复数形式。 - 应用旋转编码,将位置信息融入
x
。 - 将结果转换回实数形式。
旋转位置编码的核心思想是通过复数旋转将位置信息融入查询和键向量中,从而增强模型对序列位置的感知能力。
好的!以下是代码对应的数学公式和原理解释,帮助你更好地理解旋转位置编码(Rotary Position Embedding, RoPE)的实现。
1. 旋转位置编码的数学原理
旋转位置编码的核心思想是通过复数旋转将位置信息融入查询(Q
)和键(K
)向量中。具体来说,对于位置 m
和 n
,查询和键向量会被旋转一个与位置相关的角度,从而在计算注意力分数时引入位置信息。
(1) 旋转公式
对于向量 x
和位置 m
,旋转位置编码的公式为:
RoPE ( x , m ) = x ⋅ e i m θ \text{RoPE}(x, m) = x \cdot e^{i m \theta} RoPE(x,m)=x⋅eimθ
其中:
- x x x 是输入向量。
- m m m 是位置索引。
- θ \theta θ 是旋转角度,由频率向量决定。
(2) 频率向量
频率向量 θ j \theta_j θj 的计算公式为:
θ j = 1 base j / d \theta_j = \frac{1}{\text{base}^{j / d}} θj=basej/d1
其中:
- j j j 是维度索引。
- d d d 是向量的维度。
- base \text{base} base 是一个常数(代码中的
theta
)。
2. 代码的数学公式
以下是代码中每一步对应的数学公式。
(1) precompute_pos_cis
函数
-
计算频率向量 θ j \theta_j θj:
θ j = 1 base j / d \theta_j = \frac{1}{\text{base}^{j / d}} θj=basej/d1
代码实现:freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-
计算旋转角度 m θ j m \theta_j mθj:
m θ j = t ⋅ θ j m \theta_j = t \cdot \theta_j mθj=t⋅θj
代码实现:freqs = torch.outer(t, freqs).float()
-
生成复数形式的旋转编码:
pos_cis = e i m θ j \text{pos\_cis} = e^{i m \theta_j} pos_cis=eimθj
代码实现:pos_cis = torch.polar(torch.ones_like(freqs), freqs)
(2) apply_rotary_emb
函数
-
将查询和键向量转换为复数形式:
x q = x q + i x q ′ , x k = x k + i x k ′ x_q = x_q + i x_q', \quad x_k = x_k + i x_k' xq=xq+ixq′,xk=xk+ixk′
代码实现:xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-
应用旋转位置编码:
x q ′ = x q ⋅ e i m θ j , x k ′ = x k ⋅ e i m θ j x_q' = x_q \cdot e^{i m \theta_j}, \quad x_k' = x_k \cdot e^{i m \theta_j} xq′=xq⋅eimθj,xk′=xk⋅eimθj
代码实现:xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
3. 公式与代码的对应关系
数学公式 | 代码实现 |
---|---|
θ j = 1 base j / d \theta_j = \frac{1}{\text{base}^{j / d}} θj=basej/d1 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) |
m θ j = t ⋅ θ j m \theta_j = t \cdot \theta_j mθj=t⋅θj | freqs = torch.outer(t, freqs).float() |
pos_cis = e i m θ j \text{pos\_cis} = e^{i m \theta_j} pos_cis=eimθj | pos_cis = torch.polar(torch.ones_like(freqs), freqs) |
x q = x q + i x q ′ x_q = x_q + i x_q' xq=xq+ixq′ | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
x q ′ = x q ⋅ e i m θ j x_q' = x_q \cdot e^{i m \theta_j} xq′=xq⋅eimθj | xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) |
4. 示例
假设:
- 输入向量 x q = [ 1 , 2 , 3 , 4 ] x_q = [1, 2, 3, 4] xq=[1,2,3,4]。
- 位置索引 m = 1 m = 1 m=1。
- 频率向量 θ j = [ 0.1 , 0.2 ] \theta_j = [0.1, 0.2] θj=[0.1,0.2]。
(1) 计算旋转角度
m θ j = [ 1 ⋅ 0.1 , 1 ⋅ 0.2 ] = [ 0.1 , 0.2 ] m \theta_j = [1 \cdot 0.1, 1 \cdot 0.2] = [0.1, 0.2] mθj=[1⋅0.1,1⋅0.2]=[0.1,0.2]
(2) 生成复数形式的旋转编码
pos_cis = e i [ 0.1 , 0.2 ] = [ cos ( 0.1 ) + i sin ( 0.1 ) , cos ( 0.2 ) + i sin ( 0.2 ) ] \text{pos\_cis} = e^{i [0.1, 0.2]} = [\cos(0.1) + i \sin(0.1), \cos(0.2) + i \sin(0.2)] pos_cis=ei[0.1,0.2]=[cos(0.1)+isin(0.1),cos(0.2)+isin(0.2)]
(3) 应用旋转位置编码
x q ′ = [ 1 + i 2 , 3 + i 4 ] ⋅ [ cos ( 0.1 ) + i sin ( 0.1 ) , cos ( 0.2 ) + i sin ( 0.2 ) ] x_q' = [1 + i 2, 3 + i 4] \cdot [\cos(0.1) + i \sin(0.1), \cos(0.2) + i \sin(0.2)] xq′=[1+i2,3+i4]⋅[cos(0.1)+isin(0.1),cos(0.2)+isin(0.2)]
5. 总结
通过数学公式和代码的对应关系,可以看出旋转位置编码的核心是通过复数旋转将位置信息融入查询和键向量中。具体步骤包括:
- 计算频率向量 θ j \theta_j θj。
- 计算旋转角度 m θ j m \theta_j mθj。
- 生成复数形式的旋转编码 e i m θ j e^{i m \theta_j} eimθj。
- 将旋转编码应用到查询和键向量上。
这种方法能够在不增加额外参数的情况下,显著提升模型对序列位置的感知能力。