torch.cat函数详解

torch.cat是PyTorch深度学习框架中的一个函数,用于将多个张量沿着指定的维度拼接在一起。具体来说,它可以将多个形状相同的张量按照指定的维度进行拼接,返回一个新的张量。

torch.cat的语法如下:

torch.cat(seq, dim=0, *, out=None) -> Tensor

其中,参数seq是要拼接的张量序列,它们应该具有相同的形状(除了沿着拼接维度的大小),可以是一个Python列表或元组。参数dim表示要沿着哪个维度进行拼接,默认值为0(即第0个维度)。返回值是一个新的张量,与输入张量在拼接维度上的大小之和相同。

例如,假设有两个形状为(3, 4)的张量x1和x2,我们可以使用以下代码将它们沿着第0个维度进行拼接:

import torch

x1 = torch.randn(3, 4)
x2 = torch.randn(3, 4)
y = torch.cat([x1, x2], dim=0)
print(y.shape) # 输出:torch.Size([6, 4])

上述代码将输出一个形状为(6, 4)的新张量y,其中前三行是x1的内容,后三行是x2的内容。

猜你喜欢

转载自blog.csdn.net/qq_45138078/article/details/129860313