Function overview
Purpose
Gathers values from the input tensor along the specified dimension, at the positions given by index
Parameters
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
- input ( Tensor ) – the source tensor
- dim ( int ) – the axis along which to index
- index ( LongTensor ) – the indices of the elements to gather
- sparse_grad ( bool, optional ) – if True, the gradient with respect to input will be a sparse tensor
- out ( Tensor, optional ) – the output tensor
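Notes
input and index must have the same number of dimensions. For every dimension d != dim, index.size(d) <= input.size(d) is also required; the size of index along dim itself is not constrained. Each value in index must lie between 0 and input.size(dim) - 1. The output has the same shape as index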
2D-Tensor example
dim=0
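- First, create an input tensor with the values 1 to 16 and reshape it to 4x4
import torch
# torch.range is deprecated; torch.arange excludes the end value, so the end is 17
x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])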
"""
- Then create the index
For example, [[0, 1, 2, 3], [3, 2, 1, 0]]
Consider the first index row, [0, 1, 2, 3]. With dim=0, each value is a row number in input, so these values select from rows 0, 1, 2, and 3. The column comes from each value's own position in index: the values sit in columns 0, 1, 2, and 3. The output is therefore input[0][0], input[1][1], input[2][2], input[3][3], namely [1., 6., 11., 16.]
Now consider the second index row, [3, 2, 1, 0]. These values select from rows 3, 2, 1, and 0, and they again sit in columns 0, 1, 2, and 3 of index, so the output is input[3][0], input[2][1], input[1][2], input[0][3], namely [13., 10., 7., 4.]
index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
- Gather along dim=0; the printed result matches the values derived above
y = torch.gather(x, dim=0, index=index)
"""
tensor([[ 1.,  6., 11., 16.],
        [13., 10.,  7.,  4.]])
"""
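To double-check the dim=0 walkthrough, the rule out[i][j] = input[index[i][j]][j] can be written as an explicit loop and compared with torch.gather. This is only a sketch for verification; the name expected is introduced here and is not part of the original example.

```python
import torch

x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)
index = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])

# Apply the dim=0 rule element by element:
# the index value picks the row, the position in index picks the column.
expected = torch.empty(index.shape, dtype=x.dtype)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        expected[i, j] = x[index[i, j], j]

y = torch.gather(x, dim=0, index=index)
print(torch.equal(y, expected))  # True
```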
dim=1
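- Create the input tensor
import torch
# torch.range is deprecated; torch.arange excludes the end value, so the end is 17
x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])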
"""
- Create the index
For example, [[0, 1, 2, 3], [3, 2, 1, 0]]
Consider [0, 1, 2, 3] first. With dim=1, each value is a column number in input, so these values select from columns 0, 1, 2, and 3. The row comes from the value's position in index: [0, 1, 2, 3] is row 0 of index, so the output is input[0][0], input[0][1], input[0][2], input[0][3], namely [1., 2., 3., 4.]
Now consider [3, 2, 1, 0]. These values select from columns 3, 2, 1, and 0, and because [3, 2, 1, 0] is row 1 of index, the output is input[1][3], input[1][2], input[1][1], input[1][0], namely [8., 7., 6., 5.]
index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
- Gather along dim=1; the printed result matches the values derived above
y = torch.gather(x, dim=1, index=index)
"""
tensor([[1., 2., 3., 4.],
        [8., 7., 6., 5.]])
"""
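For dim=1 the rule is out[i][j] = input[i][index[i][j]]. One way to confirm this, sketched below, is to express the same lookup with ordinary advanced indexing. The helper tensor rows is an assumption of this sketch, not something torch.gather needs.

```python
import torch

x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)
index = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])

# rows has shape (2, 1) and broadcasts against index (2, 4),
# so element [i, j] of the result reads x[i, index[i, j]].
rows = torch.arange(index.size(0)).unsqueeze(1)
by_indexing = x[rows, index]

y = torch.gather(x, dim=1, index=index)
print(torch.equal(y, by_indexing))  # True
```

The advanced-indexing form only works here because the row coordinate is simply the position i; torch.gather states the same intent directly and generalizes to any dim.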
Summary
When gathering from a 2D tensor, each value in index replaces the coordinate along dim. With dim=0, the value is the row to read from, and the column is the value's own column in index. With dim=1, the value is the column to read from, and the row is the value's own row in index. The output always has the same shape as index
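The rule in this summary can also be modeled without PyTorch. The function gather_2d below is a hypothetical helper written for this illustration on nested lists; it is a sketch of the indexing rule, not how torch.gather is implemented.

```python
def gather_2d(data, dim, index):
    """Pure-Python model of the 2D gather rule for nested lists."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2D input")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            # dim=0: k replaces the row coordinate.
            # dim=1: k replaces the column coordinate.
            out_row.append(data[k][j] if dim == 0 else data[i][k])
        out.append(out_row)
    return out


data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
index = [[0, 1, 2, 3], [3, 2, 1, 0]]

print(gather_2d(data, 0, index))  # [[1, 6, 11, 16], [13, 10, 7, 4]]
print(gather_2d(data, 1, index))  # [[1, 2, 3, 4], [8, 7, 6, 5]]
```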