【Python基础查漏补缺】常用数组或矩阵的维度操作(切片、压缩展开、转置、排列、展平)

1 切片操作

大体形式:

x[:, :, :, :]

这个操作是最基本,也是初学时最难理解的一个操作。不管是在np.array数组中,还是在torch.tensor中,都可以用这种通用方式去切片出我们需要的矩阵。
简单切片操作:

x[idx_start:idx_end:stride]  #x[起点:终点:步长]

带逗号的切片操作:

x[idx_start:idx_end, idx_start:idx_end:stride]

逗号的作用是区分维度(记住这个,基本就理解这类语法了),如果步长取-1,则代表从后往前取,但是要注意一点,逗号前面的不能限定步长。另外,补充一个常识:遇到这种[m:n]语法时,牢记左闭右开,即左侧m能取到,右侧n取不到,只能取到n-1(联想range(0, n)这种语法同理,取值范围是0~n-1)。
示例:
取第三个维度的5-7维,取第四个维度的0, 1维

x = torch.zeros(8, 1, 16, 3)
y = x[:, :, 5:8, :-1]
print(y.shape)
torch.Size([8, 1, 3, 2])

小心连续切片的问题!
还是上面举的例子,取第三个维度的5-7维,取第四个维度的0, 1维,如果按矩阵取值的方式(用多个中括号去分别定位每个维度)去写会是下面这样:

x = torch.zeros(8, 1, 16, 3)
y = x[:][:][5:8][:-1]
print(y.shape)
torch.Size([2, 1, 16, 3])

这个结果显然是错误的!因为把第一个维度中的8给切成了2,别的都没变化。
所以这里理解连续切片的概念,我们debug分析上面的代码:

  • x[:]得到的结果是:torch.Size([8, 1, 16, 3]),显然没变化
  • x[:][:]得到的结果是:torch.Size([8, 1, 16, 3]),仍然没变化
  • x[:][:][5:8]的结果是:torch.Size([3, 1, 16, 3]),第一个维度8->3
  • x[:][:][5:8][:-1]的结果是:torch.Size([2, 1, 16, 3]),第一个维度3->2

由此可见,这种连续切片的方式(理解为连续多个中括号) 并没有分别去改变每个维度,一定是 在上一步切片的结果上,进行一次新的切片,和x[3][1][2][4]矩阵取值的思路完全不一样!
如果要实现对每个维度的分别切片,还得用上面那个例子中的写法:

x[:, :, 5:8, :-1]

上面我只提及了常见的用法和坑点,详细的教程和例子可以参考:

2 压缩和展开操作

大体形式:

x.squeeze(dim=n)
x.unsqueeze(dim=n)

很多时候,我们都需要将矩阵展开维度或压缩维度后进行矩阵的运算。squeeze()函数为压缩操作,将某一个维度值为1的维度进行删减,或将多个维度值为1的维度进行删减;unsqueeze()函数为展开操作,将某个维度补上维度值为1的维度。
示例:
在256和32对应维度之间补上一个维度1

x = torch.zeros(8, 256, 32, 64)
y = x.unsqueeze(dim=2)
print(y.shape)
torch.Size([8, 256, 1, 32, 64])

删掉前面那个维度1

x = torch.zeros(8, 1, 256, 1)
y = x.squeeze(dim=1)
print(y.shape)
torch.Size([8, 256, 1])

同时删掉多个(所有)维度1

x = torch.zeros(8, 1, 256, 1)
y = x.squeeze()  # 留空就是删掉所有维度1
print(y.shape)
torch.Size([8, 256])

3 转置操作

大体形式:

x.transpose(m, n)

transpose()代表转置,即线性代数中将两个位置进行交换,在高维矩阵中同理。
示例:
将256与32对应维度进行交换

x = torch.zeros(8, 256, 32, 64)
y = x.transpose(1, 2)
print(y.shape)
torch.Size([8, 32, 256, 64])

4 排列操作

大体形式:

x.permute(c, b, a, d)

permute()函数是transpose()的更一般形式,因为它可以同时处理多个位置的顺序变换。这个操作非常好理解,假如原矩阵是(8, 256, 32, 64)的维度,那么0位置对应8,1位置对应256,2位置对应32,3位置对应32,如果我们想让矩阵变为(256, 8, 32, 64),那就需要交换位置0和1,于是语法如下:

x.permute(1, 0, 2, 3)

示例:
8对应维度不动,256对应维度移动到末尾,32对应维度移动到第二个位置,64对应维度移动到倒数第二个位置

x = torch.zeros(8, 256, 32, 64)
y = x.permute(0, 2, 3, 1)
print(y.shape)
torch.Size([8, 32, 64, 256])

不改变各个维度的位置

x = torch.zeros(8, 256, 32, 64)
y = x.permute(0, 1, 2, 3)  # 注意:这个就是表示原位置
print(y.shape)
torch.Size([8, 256, 32, 64])

5 展平操作

大体形式:

x.flattten(m, n)

在神经网络搭建中,时常会在全连接之前将矩阵的某两个维度合并在一起(如H x W),这个操作叫做展平(flatten)。 注意,flatten()函数中的两个位置索引要求m<n,否则会报错!
示例:
展平最后两维

x = torch.zeros(8, 1, 16, 3)
y = x.flatten(-2, -1)
print(y.shape)
torch.Size([8, 1, 48])

猜你喜欢

转载自blog.csdn.net/qq_16763983/article/details/126551115
今日推荐