torch.stack()的使用

# Demonstrates how torch.stack() joins equally-shaped tensors along a NEW
# dimension, and how the choice of `dim` determines where that dimension lands.
import torch  # the snippet uses torch but never imported it

# Two random tensors of identical shape (1, 2, 3); torch.stack requires all
# inputs to have the same shape.
train_X = torch.randn(1, 2, 3)
train_Y = torch.randn(1, 2, 3)
print(train_X)
print(train_Y)

# The sequence of tensors to be stacked.
features = [train_X, train_Y]
print(features)

# New axis inserted at index 1: st[:, i] is features[i].
st = torch.stack(features, dim=1)  # torch.Size([1, 2, 2, 3])
print(st)

# New axis inserted at index 2: st[:, :, i] is features[i].
# The shape matches the dim=1 result only because the original size at
# index 1 happens to equal the number of stacked tensors; the element
# layout is different.
st = torch.stack(features, dim=2)  # torch.Size([1, 2, 2, 3])
print(st)

# New axis appended as the last dimension: st[..., i] is features[i].
st = torch.stack(features, dim=3)  # torch.Size([1, 2, 3, 2])
print(st)

猜你喜欢

转载自blog.csdn.net/tony2278/article/details/105197774