Python常用库:rearrange函数——转换数组维度

Python常用库:rearrange函数——转换数组维度

einops.rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)

功能:重新划分张量维度,可以实现数组的转置、拆分、合并等操作。

输入:

  • tensor:需要调整维度的张量数据;
  • pattern:调整规则;
  • axes_lengths:附加的尺寸规格;

注意:

  • 如果需要将某一维度拆分成多个维度,需要额外指定一些附加的尺寸规格变量,同时拆分或者合并维度时,注意变量顺序;

代码案例

拆分

import torch
from einops import rearrange

data = torch.range(1, 10)
data1 = rearrange(data, '(a b) -> a b', a=2, b=5)
data2 = rearrange(data, '(b a) -> a b', a=2, b=5)
print(data1)
print(data2)

输出

# (a b) -> a b
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])
# (b a) -> a b
tensor([[ 1.,  3.,  5.,  7.,  9.],
        [ 2.,  4.,  6.,  8., 10.]])

注意:(a b) -> a b时,相当于直接按顺序拆分,每b个为1组,一共分出a组来,b看成每组的特征长度,a看成组数;(b a) -> a b时,相当于先把数据划分成(b, a)的,之后再做一次转置,即:

print(data.reshape(5, 2).transpose(-1, -2))

# 输出
tensor([[ 1.,  3.,  5.,  7.,  9.],
        [ 2.,  4.,  6.,  8., 10.]])

合并

import torch
from einops import rearrange

data = torch.range(1, 10).reshape(2, 5)
data1 = rearrange(data, 'a b -> (a b)')
data2 = rearrange(data, 'a b -> (b a)')
print(data)
print(data1)
print(data2)

输出

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])
# a b -> (a b)
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
# a b -> (b a)
tensor([ 1.,  6.,  2.,  7.,  3.,  8.,  4.,  9.,  5., 10.])

注意:(a b) -> a b时,相当于直接按顺序合并,a个组的特征,按顺序串联合并;a b -> (b a)时,相当于先把数组做转置,之后再合并,即:

print(data.transpose(-1, -2).reshape(-1))

tensor([ 1.,  6.,  2.,  7.,  3.,  8.,  4.,  9.,  5., 10.])

官网文档:https://einops.rocks/api/rearrange/

猜你喜欢

转载自blog.csdn.net/qq_50001789/article/details/136158442