官网:https://pytorch.org/docs/stable/torch.html#torch.cat
torch.cat 进行tensor结构的拼接操作。
torch.
cat
(tensors, dim=0, out=None) → Tensor
tensors是input,进行连接的tensor必须具有相同的结构。dim拼接的纬度,0按照行(列不变行增加),1按列(行不变列增加)。
官方示例:
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])
注意下面两个不等同:维度不同了。
torch.cat((x,x,x),0)
torch.cat((torch.cat((x,x),0),x),0)