深度学习框架_PyTorch_torch.stack()函数和torch.cat()函数

torch.stcak()函数对多个张量在维度上进行叠加。
其中参数dim代表不同的维度。
具体如下代码所示:

>>> a = torch.ones(3,3)
>>> a
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> b = torch.ones(3,3) + 1
>>> b
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
>>> c = torch.ones(3,3) + 2
>>> c
tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
# 当dim=0时,不同的张量直接叠加
>>> d = torch.stack((a,b,c),0)
>>> d
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]])
#当dim=1时,不同的张量在第一维度组合,并叠加
>>> d = torch.stack((a,b,c),1)
>>> d
tensor([[[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.]],

        [[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.]],

        [[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.]]])
#当dim=2时,不同的张量在第二维度组合,并叠加
>>> d = torch.stack((a,b,c),2)
>>> d
tensor([[[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]]])
# 当dim=-1时,就是在最后一个维度组合,并叠加
>>> d = torch.stack((a,b,c),-1)
>>> d
tensor([[[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]]])

torch.cat()函数对多个张量进行某一维度的拼接,拼接后的总维度数不变。
其中参数dim代表了不同的维度。

解析来我们从代码中进行分析:

# 当dim=0时,从第一维度进行拼接
>>> d = torch.cat((a,b,c),0)
>>> d
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])
# 当dim=1时,从第二维度进行拼接
>>> d = torch.cat((a,b,c),1)
>>> d
tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.],
        [1., 1., 1., 2., 2., 2., 3., 3., 3.],
        [1., 1., 1., 2., 2., 2., 3., 3., 3.]])
# 当dim=-1时,从对后一个维度进行拼接
>>> d = torch.cat((a,b,c),-1)
>>> d
tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.],
        [1., 1., 1., 2., 2., 2., 3., 3., 3.],
        [1., 1., 1., 2., 2., 2., 3., 3., 3.]])

注意:torch.cat()函数存在特例。若用torch.unsqueeze()函数对上述的a,b张量进行升维,在用torch.cat()函数可进行通道数叠加操作。

如下面的代码所示:

>>> a = a.unsqueeze(0)
>>> a
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

>>> a.size()
torch.Size([1, 3, 3])

>>> b = b.unsqueeze(0)
>>> b
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

>>> b.size()
torch.Size([1, 3, 3])

# 在没有指定dim时默认为通道数叠加
>>> c = torch.cat((a,b))
>>> c
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

# dim=0时是第一维拼接,即通道数叠加
>>> c = torch.cat((a,b),0)
>>> c
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])

>>> c.size()
torch.Size([2, 3, 3])

# dim=1时第二维拼接
>>> c = torch.cat((a,b),1)
>>> c
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])
发布了156 篇原创文章 · 获赞 48 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/Rocky6688/article/details/104354699