Pytorch常用的函数(一)torch.squeeze()和torch.unsqueeze()、torch.cat()和torch.stack()函数功能及使用

Pytorch常用的函数

1、torch.squeeze()和torch.unsqueeze()函数功能及使用

1.1 torch.squeeze()

1.1.1 squeeze(1)和squeeze(-1)

两者的效果一样,都是给张量降维,但是如果不是n*1的这种2维张量的话,调用这个函数一点效果没有。

a = torch.tensor(
    [
        [1],
        [2],
        [3]
    ]
)

b = a.squeeze(1)
c = a.squeeze(-1)
print(a.shape)
print(b)
print(c)
torch.Size([3, 1])
tensor([1, 2, 3])
tensor([1, 2, 3])

1.1.2 squeeze(0)

当张量是一个1*n维度的张量时,可以调用这个函数。

但是如果不是1*n的这种2维张量的话,调用这个函数一点效果没有。

a = torch.tensor(
    [
        [1, 2, 3]
    ]
)

b = a.squeeze(0)
print(a)
print(b)
tensor([[1, 2, 3]])
tensor([1, 2, 3])

1.1.3 squeeze()函数详解

现在详解了解下这个函数。

torch.squeeze(input, dim=None, out=None) 

squeeze()函数的功能是维度压缩。返回一个tensor(张量),其中 input 中维度大小为1的所有维都已删除。

举个例子:如果 input 的形状为 (A×1×B×C×1×D),那么返回的tensor的形状则为 (A×B×C×D)

a = torch.randn(size=(2, 1, 2, 1, 2))
b = a.squeeze(1)      #表示仅仅把a中第2维进行删除
c = torch.squeeze(a)  #表示把a中维度大小为1的所有维都已删除

print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 1, 2])
torch.Size([2, 2, 2])

当给定 dim 时,那么只在给定的维度(dimension)上进行压缩操作,注意给定的维度大小必须是1,否则不能进行压缩。
举个例子:如果 input 的形状为 (A×1×B),squeeze(input, dim=0)后,返回的tensor不变,因为第0维的大小为A,不是1; 而改成squeeze(input, 1)后,返回的tensor将被压缩为 (A×B)。

a = torch.randn(size=(2, 1, 2, 1, 2))
b = torch.squeeze(a,0)  # 表示把a中第1维删除,但是第1维大小为2,不为1,因此结果删除不掉
c = a.squeeze(0)        # 表示把a中第1维删除,但是第1维大小为2,不为1,因此结果删除不掉

print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 1, 2, 1, 2])
a = torch.randn(size=(2, 1, 2, 1, 2))
b = torch.squeeze(a,1)  # 表示把a中第2维删除,因为第2维大小是1,因此可以删掉
c = a.squeeze(1)        # 表示把a中第2维删除,因为第2维大小是1,因此可以删掉

print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 1, 2])
torch.Size([2, 2, 1, 2])

1.2 torch.unsqueeze()

1.2.1 unsqueeze(1)和unsqueeze(-1)

作用和squeeze(1)和squeeze(-1)相反,用于给张量升维,函数效果如下所示。

如果我就是一n*m的2维的张量,调用这两个函数后,一点效果没有。

a = torch.tensor([1,2,3])
b = a.unsqueeze(1)
c = a.unsqueeze(-1)

print(a)
print(b)
print(c)
tensor([1, 2, 3])
tensor([[1],
        [2],
        [3]])
tensor([[1],
        [2],
        [3]])

1.2.2 unsqueeze(0)

和squeeze(0)作用相反,函数效果如下所示。

a = torch.tensor([1, 2, 3])
b = a.unsqueeze(0)


print(a)
print(b)
tensor([1, 2, 3])
tensor([[1, 2, 3]])

1.2.3 unsqueeze函数详解

unsqueeze()函数起升维的作用,参数dim表示在哪个地方加一个维度。

torch.unsqueeze(input, dim) → Tensor

注意dim范围在:[-input.dim() - 1, input.dim() ]之间,比如输入input是1维,则dim=0时数据为行方向扩,dim=1时为列方向扩,再大错误。

a = torch.tensor([1, 2, 3, 4])
b = torch.unsqueeze(a, 0)  #在第0维(行)扩展,第0维大小为1
c = a.unsqueeze(0)         #在第0维(行)扩展,第0维大小为1

print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([4])
torch.Size([1, 4])
torch.Size([1, 4])
a = torch.tensor([1, 2, 3, 4])
b = torch.unsqueeze(a, 1)  #在第1维(列)扩展,第1维大小为1,注意:-1表示在最后一维扩展,最后一维大小为1
c = a.unsqueeze(1)         #在第1维(列)扩展,第1维大小为1,注意:-1表示在最后一维扩展,最后一维大小为1

print(a.shape)
print(b.shape)
print(c.shape)
torch.Size([4])
torch.Size([4, 1])
torch.Size([4, 1])

2、torch.cat()和torch.stack()函数功能及使用

2.1 torch.cat()

torch.cat(tensors, dim=0, *, out=None) → Tensor

函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外,其余维数数值需相同,方能对齐。

torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接。

a = torch.zeros(2,2)
b =  torch.ones(2,2)

c = torch.cat((a,b),dim=0)  # dim=0,对行进行拼接
d = torch.cat((a,b),dim=1)  # dim=1,对列进行拼接
print(a)
print(b)
print(c)
print(d)
tensor([[0., 0.],
        [0., 0.]])
tensor([[1., 1.],
        [1., 1.]])
        
tensor([[0., 0.],
        [0., 0.],
        [1., 1.],
        [1., 1.]])
        
tensor([[0., 0., 1., 1.],
        [0., 0., 1., 1.]])

三维张量:dim=0,对通道进行拼接

a = torch.zeros(2,3,4)
b =  torch.ones(1,3,4)
c = torch.cat((a,b),dim=0)  # 除拼接维数dim数值可不同外,其余维数数值需相同
a,b,c
(tensor([[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]),
          
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]),
          
  tensor([[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]))         

三维张量:dim=1,对行进行拼接

a = torch.zeros(2,1,4)
b =  torch.ones(2,2,4)
c = torch.cat((a,b),dim=1) # 除拼接维数dim数值可不同外,其余维数数值需相同
a,b,c
(tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]]),
         
 tensor([[[1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.]]]),
  
 tensor([[[0., 0., 0., 0.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
 
         [[0., 0., 0., 0.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]))

三维张量:dim=2,对列进行拼接

a = torch.zeros(2,2,3)
b =  torch.ones(2,2,1)
c = torch.cat((a,b),dim=2)  # 除拼接维数dim数值可不同外,其余维数数值需相同
a,b,c
(tensor([[[0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.]]]),
          
 tensor([[[1.],
          [1.]],
 
         [[1.],
          [1.]]]),
          
 tensor([[[0., 0., 0., 1.],
          [0., 0., 0., 1.]],
 
         [[0., 0., 0., 1.],
          [0., 0., 0., 1.]]]))

2.2 torch.stack()

torch.stack(tensors, dim=0, *, out=None) → Tensor
  • tensors :为一系列输入张量,类型为turple和List
  • dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接
  • 返回值:输出新增维度后的张量

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。

也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

二维张量:dim=0,通道维度上进行组合

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], dim=0)  #在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维)


a,b,c
(tensor([[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]),
              
 tensor([[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]),
           
 tensor([[[ 1,  2,  3],
          [ 4,  5,  6],
          [ 7,  8,  9]],
 
         [[11, 22, 33],
          [44, 55, 66],
          [77, 88, 99]]]))

此时,我们用cat()进行连接

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.cat([a, b], dim=0)  #在第0维进行连接,相当对于行进行组合(输入张量为两维,输出张量为二维)


a,b,c
(tensor([[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]),
         
 tensor([[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]),
         
tensor([[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]))

二维张量:dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)  #在第1维进行连接,相当于对相应通道中每个行进行组合

a,b,c
(tensor([[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]),
 tensor([[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]),
 tensor([[[ 1,  2,  3],
          [11, 22, 33]],
          
         [[ 4,  5,  6],
          [44, 55, 66]],
 
         [[ 7,  8,  9],
          [77, 88, 99]]]))

二维张量:dim=2:表示在第2维进行连接,相当于对相应通道中每个列进行组合

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)  #在第2维进行连接,相当于对相应通道中每个行进行组合

a,b,c
(tensor([[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]),
 tensor([[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]),
  tensor([[[ 1, 11],
          [ 2, 22],
          [ 3, 33]],
 
         [[ 4, 44],
          [ 5, 55],
          [ 6, 66]],
 
         [[ 7, 77],
          [ 8, 88],
          [ 9, 99]]]))

猜你喜欢

转载自blog.csdn.net/qq_44665283/article/details/131053369