torch.stack(list,0)
list 中的每个元素为tensor 中第0维度的每个元素
import torch
a = torch.Tensor([[1, 3, 2], [1, 3, 2]])
b = torch.Tensor([[2, 1, 1], [2, 1, 1]])
c = torch.Tensor([[3, 2, 3], [2, 1, 1]])
my_list = [a, b, c]
print(torch.stack(my_list, 0))
》》》
tensor([[[ 1., 3., 2.],
[ 1., 3., 2.]],
[[ 2., 1., 1.],
[ 2., 1., 1.]],
[[ 3., 2., 3.],
[ 2., 1., 1.]]])
troch.stack(list,1) list中每个元素的第0维元素-第n维元素各成一组
print(torch.stack(my_list, 1))
》》》
tensor([[[ 1., 3., 2.],
[ 2., 1., 1.],
[ 3., 2., 3.]],
[[ 1., 3., 2.],
[ 2., 1., 1.],
[ 2., 1., 1.]]])