torch.gather() function

See the official documentation: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

Definition: Collects values from the input tensor along the dimension dim, at the positions specified by index.

Purpose: It makes it convenient to pick out the values at specified indices from a batch tensor. The index is fully customizable, and positions can be chosen in any order.
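As a rough sketch of the batch use case, the snippet below picks one value per sample from a batch of scores. The tensor values and the names `scores` and `labels` are hypothetical, chosen only for illustration. Note that index must have the same number of dimensions as the input, hence the `unsqueeze(1)`.

```python
import torch

# Hypothetical batch: scores for 3 samples over 4 classes (values chosen for illustration)
scores = torch.tensor([[0.1, 0.7, 0.1, 0.1],
                       [0.5, 0.2, 0.2, 0.1],
                       [0.0, 0.1, 0.3, 0.6]])
# Hypothetical per-sample class positions we want to pick out
labels = torch.tensor([1, 0, 3])

# index must have as many dimensions as scores, so reshape labels from (3,) to (3, 1)
picked = torch.gather(scores, dim=1, index=labels.unsqueeze(1))
print(picked)  # shape (3, 1): one value per sample
```

Each row of `picked` is `scores[i][labels[i]]`, so the output has the shape of the index, (3, 1).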

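Index is a row vector, dim = 0

With dim=0, each value in index replaces the row subscript, while the column subscript stays that of the element's own position in index: out[0][j] = input[index[0][j]][j].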
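A minimal sketch of the dim=0 case, assuming a hypothetical 3×3 input built with `torch.arange` (the values in the original figure are not available):

```python
import torch

# Hypothetical 3x3 input, chosen so every element is distinct
x = torch.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
index = torch.tensor([[2, 1, 0]])      # row vector, shape (1, 3)

# dim=0: out[0][j] = x[index[0][j]][j], i.e. the row subscript is replaced
out = torch.gather(x, 0, index)
print(out)  # tensor([[7, 5, 3]])
```

The three reads are x[2][0], x[1][1] and x[0][2].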

Index is a row vector, dim = 1


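With dim=1, each value in index replaces the column subscript, while the row subscript stays the same: out[0][j] = input[0][index[0][j]].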

Initial position (in index)    Value in index (dim=1)    Position read from the input
(0, 0)                         2                         (0, 2)
(0, 1)                         1                         (0, 1)
(0, 2)                         0                         (0, 0)

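Why are the initial positions (0, 0), (0, 1) and (0, 2)?

Look at index = [[2, 1, 0]]: it has shape 1×3, so (0, 0), (0, 1) and (0, 2) are simply the subscripts of its three elements. With dim = 1, each element keeps its row subscript and has its column subscript replaced by its own value.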
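The following sketch reproduces the table above, assuming a hypothetical 3×3 input. It enumerates each subscript of index together with the subscript it reads when dim=1, then computes the values with `torch.gather`:

```python
import torch

x = torch.arange(1, 10).reshape(3, 3)  # hypothetical input [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
index = torch.tensor([[2, 1, 0]])      # shape (1, 3)

# Walk over every subscript (i, j) of index and show which subscript of x it reads with dim=1
mapping = []
for i in range(index.size(0)):
    for j in range(index.size(1)):
        v = index[i, j].item()
        mapping.append(((i, j), v, (i, v)))  # column subscript j is replaced by v
        print((i, j), v, (i, v))

out = torch.gather(x, 1, index)
print(out)  # tensor([[3, 2, 1]])
```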

If index is a column vector, the same rule applies: with dim = 0 the row subscript is replaced, and with dim = 1 the column subscript is replaced.
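A sketch of the column-vector case, under the same assumption of a hypothetical 3×3 input:

```python
import torch

x = torch.arange(1, 10).reshape(3, 3)  # hypothetical input [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
index = torch.tensor([[2], [1], [0]])  # column vector, shape (3, 1)

# dim=0: out[i][0] = x[index[i][0]][0]  (row subscript replaced, column stays 0)
out0 = torch.gather(x, 0, index)
# dim=1: out[i][0] = x[i][index[i][0]]  (column subscript replaced, row stays i)
out1 = torch.gather(x, 1, index)
print(out0.tolist())  # [[7], [4], [1]]
print(out1.tolist())  # [[3], [5], [7]]
```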


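For a two-dimensional index matrix with dim = 1, the rule is applied element by element: each value replaces the column subscript of its own position, and the row subscript stays unchanged. In other words, index[i][j] = v reads input[i][v].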
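A small sketch with a hypothetical 3×2 index matrix. It shows that positions may repeat and that the output takes the shape of index rather than the shape of the input. For the dimensions other than dim, index only needs to be no larger than the input.

```python
import torch

x = torch.arange(1, 10).reshape(3, 3)  # hypothetical input [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# Hypothetical 3x2 index matrix; along dim 0 it has 3 rows, which is <= x.size(0)
index = torch.tensor([[0, 2],
                      [1, 1],
                      [2, 0]])

out = torch.gather(x, 1, index)
# Row 0: x[0][0], x[0][2] -> 1, 3
# Row 1: x[1][1], x[1][1] -> 5, 5
# Row 2: x[2][2], x[2][0] -> 9, 7
print(out.tolist())  # [[1, 3], [5, 5], [9, 7]]
```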


In conclusion:

  • The output has the same shape as index.
  • Each value in index replaces only the subscript along dimension dim; the subscripts along the other dimensions come from the element's own position in index.
  • Each output element is the value of the original tensor at the subscript obtained after this replacement.
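To check these three conclusions, here is a hypothetical plain-loop helper, `gather_2d`, written only for illustration (it is not part of PyTorch). It handles 2-D tensors only and compares its result with `torch.gather` for both values of dim:

```python
import torch

def gather_2d(x, dim, index):
    """Hypothetical helper: a plain-loop version of torch.gather for 2-D tensors."""
    out = torch.empty_like(index, dtype=x.dtype)  # output shape == index shape
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            v = int(index[i, j])
            # Replace only the subscript along `dim`; keep the other one from (i, j)
            out[i, j] = x[v, j] if dim == 0 else x[i, v]
    return out

x = torch.arange(1, 10).reshape(3, 3)          # hypothetical input
index = torch.tensor([[2, 0, 1], [1, 2, 0]])   # shape (2, 3)
for d in (0, 1):
    assert torch.equal(gather_2d(x, d, index), torch.gather(x, d, index))
print(gather_2d(x, 0, index).tolist())  # [[7, 2, 6], [4, 8, 3]]
```

A loop like this is far slower than `torch.gather`; it is only meant to make the subscript replacement explicit.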


Origin blog.csdn.net/weixin_43537097/article/details/132457209