torch.cat()

官网:https://pytorch.org/docs/stable/torch.html#torch.cat

torch.cat 进行tensor结构的拼接操作。

torch.cat(tensorsdim=0out=None) → Tensor

tensors是input,进行连接的tensor必须具有相同的结构。dim拼接的纬度,0按照行(列不变行增加),1按列(行不变列增加)。

官方示例:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

注意下面两个不等同:维度不同了。 

torch.cat((x,x,x),0)
torch.cat((torch.cat((x,x),0),x),0)
发布了56 篇原创文章 · 获赞 29 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/foneone/article/details/103853713