版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_26114733/article/details/88088077
>>> a
tensor([[ 0.9918, 0.4911, 1.4912, -1.8491],
[ 0.1257, -0.4406, 0.3371, 0.1205],
[ 0.3064, -0.8198, 1.2851, 0.2486]])
>>> b
tensor([[0, 1],
[1, 2],
[2, 2]])
>>> a.unsqueeze(1).expand(3,2,4).gather(dim=0,index=b.unsqueeze(2).expand(3,2,4))
tensor([[[ 0.9918, 0.4911, 1.4912, -1.8491],
[ 0.1257, -0.4406, 0.3371, 0.1205]],
[[ 0.1257, -0.4406, 0.3371, 0.1205],
[ 0.3064, -0.8198, 1.2851, 0.2486]],
[[ 0.3064, -0.8198, 1.2851, 0.2486],
[ 0.3064, -0.8198, 1.2851, 0.2486]]])