torch.cat() 和 torch.stack()

小知识,大挑战!本文正在参与“程序员必备小知识”创作活动。

torch.cat

torch.cat(tensors, dim=0,  * , out=None) → Tensor

将参数的张量连接起来,张量必须是同纬度的或者空的。

  • 参数
    • tensors (sequence of Tensors) – 以元组的形式传入多个张量
    • dim (int , optional) – 张量连接的维数。以二维为例:0是行,1是列... 你连接的张量是几维的参数就可以选几维。比如两个二维张量你就只能选(-2, -1, 0, 1),如果是三维的,你就只能选(-3, -2, -1, 0, 1, 2)
      • image.png 经过我的测试-2和0结果是一样的,-1和1的结果是一样的。我的猜测是-1和0为界限,右边几个左边就几个,从左到右作用一样。
import torch

t1 = torch.rand(2, 3)
t2 = torch.rand(2, 3)
t3 = torch.cat((t1, t2), dim=0)
print(t1)
print(t2)
print(t3)

t4 = torch.cat((t1, t2), dim=1)
print(t4)
复制代码

输出:

tensor([[0.7839, 0.1447, 0.4310],
        [0.9642, 0.5121, 0.3178]])
tensor([[0.7691, 0.2200, 0.8842],
        [0.6078, 0.9669, 0.9191]])
tensor([[0.7839, 0.1447, 0.4310],
        [0.9642, 0.5121, 0.3178],
        [0.7691, 0.2200, 0.8842],
        [0.6078, 0.9669, 0.9191]])
tensor([[0.7839, 0.1447, 0.4310, 0.7691, 0.2200, 0.8842],
        [0.9642, 0.5121, 0.3178, 0.6078, 0.9669, 0.9191]])
复制代码

torch.stack

torch.stack(tensors, dim=0,  * , out=None) → Tensor

将张量链接成一个新张量,按照新的维度

  • 参数
    • tensors (sequence of Tensors) – 需要连接的张量
    • dim (int) – 插入的维度,必须是0和连接之后张量的维度之间。
      • 这个是会增加一维,两个二维连接成三维这种。dim可以选到连接之后的参数。比如二维的是[-3,2],三维是[-4,3]。

这个输出就不搞了,看一串数也没啥好看的。多维的肉眼也不好看。

猜你喜欢

转载自juejin.im/post/7017780753979146253
今日推荐