torch.cat 和 torch.stack

torch.cat 和 torch.stack看起来相似但是性质还是不同的

使用python中的list列表收录tensor时,然后将list列表转化成tensor时,会报错。这个时候就要使用torch.stack进行堆叠,转化成tensor。

  • torch.cat()

torch.cat(tensors,dim=0,out=None)→ Tensor
torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变

import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))
可以看到c和a、b一样都是二维的。
  • torch.stack()

torch.stack(tensors,dim=0,out=None)→ Tensor
torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维

import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))
可以看到c是三维的,比a、b多了一维。

猜你喜欢

转载自blog.csdn.net/weixin_37707670/article/details/119644333