深度学习框架_PyTorch_torch.squeeze()函数和torch.unsqueeze()函数的用法

torch.squeeze()函数的用法主要是对数据的维度进行压缩,去掉维数为1的维度。

torch.squeeze(x)是去掉x中所有维数为1的维度;x.squeeze(n)是去掉x中指定的维数为1的维度。

接下来我们在具体代码中了解:

>>> c
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])

>>> torch.squeeze(c)
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])

>>> c =c.squeeze(0)
>>> c
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])

torch.unsqueeze()函数主要对数据进行扩充。给指定位置加上维数为1的维度。

x.squeeze(n)就是在x中指定位置n加上维数为1的维度。

我们继续看代码:

>>> c
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])

>>> c = c.unsqueeze(0)
>>> c
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])

>>> c = c.unsqueeze(1)
>>> c
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> c.size()
torch.Size([1, 1, 6, 3])

>>> c = c.unsqueeze(3)
>>> c
tensor([[[[[1., 1., 1.]],

          [[1., 1., 1.]],

          [[1., 1., 1.]],

          [[1., 1., 1.]],

          [[1., 1., 1.]],

          [[1., 1., 1.]]]]])
>>> c.size()
torch.Size([1, 1, 6, 1, 3])
发布了156 篇原创文章 · 获赞 48 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/Rocky6688/article/details/104359124