See the official documentation: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
Definition: gathers values from the input tensor along a specified dimension (dim), at the positions given by an index tensor.
Purpose: it makes it easy to pick out the entries at chosen indices from a batched tensor. The indices are fully customizable and need not be in order.
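The rule can be sketched in plain Python for the two-dimensional case. `gather2d` below is a hypothetical helper written only to mirror the semantics stated in the documentation (nested lists stand in for tensors); it is not PyTorch code.

```python
def gather2d(src, dim, index):
    """Plain-Python sketch of torch.gather for 2-D inputs (hypothetical helper).

    dim == 0: out[i][j] = src[index[i][j]][j]
    dim == 1: out[i][j] = src[i][index[i][j]]
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            # k replaces the coordinate along `dim`; the other coordinate is kept
            out_row.append(src[k][j] if dim == 0 else src[i][k])
        out.append(out_row)
    return out


src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # hypothetical 3x3 input
print(gather2d(src, 1, [[2, 1, 0]]))  # [[3, 2, 1]]
```

The loops run over index rather than over src, which is why the output always takes the shape of index.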
Index is a row vector, gathering along dim = 0
With dim = 0, each value in index replaces the row coordinate: out[i][j] = input[index[i][j]][j].
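For example, with a hypothetical 3×3 input and index = [[2, 1, 0]], each column j reads row index[0][j]. The comprehension below spells out the dim = 0 formula directly, so it runs without PyTorch.

```python
src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # hypothetical 3x3 input
index = [[2, 1, 0]]  # 1x3 row vector

# dim = 0: the value in index picks the row, the column j is kept
out = [[src[index[i][j]][j] for j in range(len(index[i]))]
       for i in range(len(index))]
print(out)  # [[7, 5, 3]]
```

With PyTorch, torch.gather(torch.tensor(src), 0, torch.tensor(index)) should give tensor([[7, 5, 3]]): element (0,0) comes from (2,0), (0,1) from (1,1), and (0,2) from (0,2).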
Index is a row vector, gathering along dim = 1
With dim = 1, each value in index replaces the column coordinate: out[i][j] = input[i][index[i][j]].
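Using the same hypothetical 3×3 input and index = [[2, 1, 0]], the dim = 1 formula keeps the row and swaps in the column, which here reverses row 0:

```python
src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # hypothetical 3x3 input
index = [[2, 1, 0]]  # 1x3 row vector

# dim = 1: the value in index picks the column, the row i is kept
out = [[src[i][index[i][j]] for j in range(len(index[i]))]
       for i in range(len(index))]
print(out)  # [[3, 2, 1]]
```

The PyTorch equivalent would be torch.gather(torch.tensor(src), 1, torch.tensor(index)), which should return tensor([[3, 2, 1]]).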
Position in index | Value stored there | Element read from input (dim = 1)
(0,0) | 2 | (0,2)
(0,1) | 1 | (0,1)
(0,2) | 0 | (0,0)
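Where do the coordinates (0, 0), (0, 1), (0, 2) come from?
Look at index = [[2, 1, 0]]: its shape is 1×3, so these are simply the subscripts of its three elements.
With dim = 1, the value stored at each position replaces the column subscript, while the row subscript stays the same.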
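The trace in the table can be regenerated mechanically: walk over every position (i, j) of index, read the value k stored there, and put it in place of the column subscript. A small sketch:

```python
index = [[2, 1, 0]]  # shape 1x3, so its positions are (0,0), (0,1), (0,2)

trace = []
for i, row in enumerate(index):
    for j, k in enumerate(row):
        # dim = 1: the stored value k replaces the column subscript j
        trace.append(((i, j), k, (i, k)))

for pos, value, source in trace:
    print(pos, value, source)
# (0, 0) 2 (0, 2)
# (0, 1) 1 (0, 1)
# (0, 2) 0 (0, 0)
```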
Index is a column vector, gathering along dim = 0 and along dim = 1
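A sketch of both cases, assuming a hypothetical 3×3 input and a hypothetical column vector index = [[2], [1], [0]]. Both outputs are 3×1, matching index; only the coordinate that gets replaced differs.

```python
src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # hypothetical 3x3 input
index = [[2], [1], [0]]  # hypothetical 3x1 column vector

# dim = 0: out[i][j] = src[index[i][j]][j]  (row replaced, column kept)
out_dim0 = [[src[index[i][j]][j] for j in range(len(index[i]))]
            for i in range(len(index))]
# dim = 1: out[i][j] = src[i][index[i][j]]  (column replaced, row kept)
out_dim1 = [[src[i][index[i][j]] for j in range(len(index[i]))]
            for i in range(len(index))]

print(out_dim0)  # [[7], [4], [1]]
print(out_dim1)  # [[3], [5], [7]]
```

With dim = 0 the result walks up column 0 from the bottom; with dim = 1 each row i contributes one element, here picking the anti-diagonal.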
Index is a two-dimensional matrix, gathering along dim = 1
Calculation:
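One way to lay out the calculation, assuming the 2×2 example from the official torch.gather documentation: input [[1, 2], [3, 4]] and index [[0, 0], [1, 0]]. Each output position (i, j) keeps its row i and reads the column stored in index[i][j].

```python
t = [[1, 2], [3, 4]]
index = [[0, 0], [1, 0]]  # 2x2 index matrix

out = []
for i, row in enumerate(index):
    out_row = []
    for j, k in enumerate(row):
        # dim = 1: position (i, j) reads t[i][k]
        print(f"out[{i}][{j}] = t[{i}][{k}] = {t[i][k]}")
        out_row.append(t[i][k])
    out.append(out_row)

print(out)  # [[1, 1], [4, 3]]
```

This matches the documented result of torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])), namely tensor([[1, 1], [4, 3]]).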
In conclusion:
- The output has the same shape as index.
- Each value stored in index replaces only the coordinate along dim; the coordinates along all other dimensions are the element's own position in index.
- Each output element is the element of the input tensor found at that modified coordinate.
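The three conclusions can be checked on a case where index is smaller than the input. This sketch assumes, as the documentation requires, that index has the same number of dimensions as the input and is no larger than it along every dimension other than dim; the 3×3 input and 2×2 index are hypothetical.

```python
src = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # hypothetical 3x3 input
index = [[1, 2], [0, 0]]  # hypothetical 2x2 index, smaller than src

# dim = 0: the index value replaces the row subscript only
out = [[src[index[i][j]][j] for j in range(len(index[i]))]
       for i in range(len(index))]

# Conclusion 1: the output has the shape of index, not of src
shape = (len(out), len(out[0]))
print(shape)  # (2, 2)
print(out)    # [[4, 8], [1, 2]]
```

For instance, out[0][1] keeps column 1 from its own position and takes row 2 from index[0][1], so it reads src[2][1] = 8.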