【pytorch学习】torch.cat

因为图像识别中网络的Tensor一般为N * C * H * W,所以我们的例子也是用4维数据

维度是从零开始的,即生成一个a = torch.randn(1, 3, 3, 4)

首先我们先从维度1开始拼接数据,即Tensor的维度C(channel)

import torch

a = torch.randn(1, 3, 3, 4) # C 为3
b = torch.randn(1, 2, 3, 4) # C 为2
# d = torch.randn(1, 2, 3, 4)
c = torch.cat((a,b), 1) # 从维度1开始拼接
print(c.size())

结果为

我们改变Tensor  b中维度2的大小即将b = torch.randn(1, 2, 3, 4)改为 b= torch.randn(1, 2, 5, 4)

会报错,报错信息中看出,无效的参数,除了Tensor的维度1的大小可以不同外,其他维度的大小必须相等。因为在我们的a,b中除了维度1大小不同外,维度2的size也不同(一个为3,一个为5)

其实就是需要拼接的数据按哪个维度拼接,哪个维度的大小就可以不同,其他维度必须相同

下面再放一个例子,改变Tensor b中维度0的大小,即 b = torch.randn(1, 2, 3, 4)改为 b= torch.randn(3, 2, 3, 4)

出现了前面类似的错误信息,只是这次是除了维度0外,其他维度必须相同。

如果有上面哪里有错误的地方,望能指出。我也还在学习。

猜你喜欢

转载自blog.csdn.net/six_water/article/details/89203788