einops.rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)
功能:重新划分张量维度,可以实现数组的转置、拆分、合并等操作。
输入:
tensor
:需要调整维度的张量数据;pattern
:调整规则;axes_lengths
:附加的尺寸规格;
注意:
- 如果需要将某一维度拆分成多个维度,需要额外指定一些附加的尺寸规格变量,同时拆分或者合并维度时,注意变量顺序;
代码案例
拆分
import torch
from einops import rearrange
data = torch.range(1, 10)
data1 = rearrange(data, '(a b) -> a b', a=2, b=5)
data2 = rearrange(data, '(b a) -> a b', a=2, b=5)
print(data1)
print(data2)
输出
# (a b) -> a b
tensor([[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.]])
# (b a) -> a b
tensor([[ 1., 3., 5., 7., 9.],
[ 2., 4., 6., 8., 10.]])
注意:(a b) -> a b时,相当于直接按顺序拆分,每b个为1组,一共分出a组来,b看成每组的特征长度,a看成组数;(b a) -> a b时,相当于先把数据划分成(b, a)的,之后再做一次转置,即:
print(data.reshape(5, 2).transpose(-1, -2))
# 输出
tensor([[ 1., 3., 5., 7., 9.],
[ 2., 4., 6., 8., 10.]])
合并
import torch
from einops import rearrange
data = torch.range(1, 10).reshape(2, 5)
data1 = rearrange(data, 'a b -> (a b)')
data2 = rearrange(data, 'a b -> (b a)')
print(data)
print(data1)
print(data2)
输出
tensor([[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.]])
# a b -> (a b)
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
# a b -> (b a)
tensor([ 1., 6., 2., 7., 3., 8., 4., 9., 5., 10.])
注意:(a b) -> a b时,相当于直接按顺序合并,a个组的特征,按顺序串联合并;a b -> (b a)时,相当于先把数组做转置,之后再合并,即:
print(data.transpose(-1, -2).reshape(-1))
tensor([ 1., 6., 2., 7., 3., 8., 4., 9., 5., 10.])