torch.stack(list)

torch.stack(list,0)

list 中的每个元素为tensor 中第0维度的每个元素

import torch

a = torch.Tensor([[1, 3, 2], [1, 3, 2]])
b = torch.Tensor([[2, 1, 1], [2, 1, 1]])
c = torch.Tensor([[3, 2, 3], [2, 1, 1]])

my_list = [a, b, c]
print(torch.stack(my_list, 0))


》》》
tensor([[[ 1.,  3.,  2.],
         [ 1.,  3.,  2.]],

        [[ 2.,  1.,  1.],
         [ 2.,  1.,  1.]],

        [[ 3.,  2.,  3.],
         [ 2.,  1.,  1.]]])

troch.stack(list,1)  list中每个元素的第0维元素-第n维元素各成一组

print(torch.stack(my_list, 1))

》》》
tensor([[[ 1.,  3.,  2.],
         [ 2.,  1.,  1.],
         [ 3.,  2.,  3.]],

        [[ 1.,  3.,  2.],
         [ 2.,  1.,  1.],
         [ 2.,  1.,  1.]]])

猜你喜欢

转载自blog.csdn.net/Z_lbj/article/details/85012764