torch.gather() function in PyTorch

1. torch.gather() function

Official documentation: torch.gather. Definition: gathers values from the original tensor along a specified dimension (dim) at the specified indices (index).
From this definition, the basic idea of gather() is simply fetching values from a full set of data by index, much like indexing into the list below:

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

The example above selects either a single value or a contiguous slice in its original order. With the batched tensors common in deep learning, we often need to select several values in an arbitrary order instead. gather() is a good tool for this: it extracts the values at the specified, possibly out-of-order, indices from a batch tensor.
Purpose: conveniently obtain the data at specified indices from a batch tensor; the selection is fully customizable and can be out of order.
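
As a minimal sketch of this kind of batch selection, the snippet below picks one value per row from a small batch. The names scores and labels and their values are hypothetical, chosen only for illustration; they do not come from the original post.

```python
import torch

# Hypothetical batch: 3 samples with 4 scores each (illustrative values).
scores = torch.arange(12).view(3, 4)
# One column index per sample, in no particular order.
labels = torch.tensor([2, 0, 3])

# gather needs an index with the same number of dimensions as the input,
# so reshape labels from (3,) to (3, 1) before gathering along dim=1.
picked = scores.gather(1, labels.unsqueeze(1)).squeeze(1)
print(picked)  # values: 2, 4, 11
```

Each sample contributes exactly one value, taken from the column named by its own entry in labels.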

2. Examples

Take a 3x3 two-dimensional matrix for the experiments:

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

Output result:

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

2.1 Row-vector index, replacing the row index (dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

The output is as follows:

tensor([[9, 7, 5]])

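The process: with dim=0, each index value replaces the row coordinate while the column stays that of the value's own position, so the output is [[tensor_0[2][0], tensor_0[1][1], tensor_0[0][2]]] = [[9, 7, 5]].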

2.2 Row-vector index, replacing the column index (dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

Output result:

tensor([[5, 4, 3]])

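The process: with dim=1, each index value replaces the column coordinate while the row stays that of the value's own position (row 0 here), so the output is [[tensor_0[0][2], tensor_0[0][1], tensor_0[0][0]]] = [[5, 4, 3]].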

2.3 Column-vector index, replacing the column index (dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

The output is as follows:

tensor([[5],
        [7],
        [9]])

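The process: the index is now a 3x1 column, so row i of the index picks one column from row i of tensor_0, and the output is [[tensor_0[0][2]], [tensor_0[1][1]], [tensor_0[2][0]]] = [[5], [7], [9]].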

2.4 Two-dimensional matrix index, replacing the column index (dim=1)

index = torch.tensor([[0, 2], 
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

Output result:

tensor([[3, 5],
        [7, 8]])

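The process is the same as in 2.2 and 2.3: row i of the index picks columns from row i of tensor_0, so the output is [[tensor_0[0][0], tensor_0[0][2]], [tensor_0[1][1], tensor_0[1][2]]].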
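
One way to double-check this reading is to reproduce the result of 2.4 with ordinary tensor indexing. The sketch below is only a cross-check, not how gather is implemented; the names rows, gathered and indexed are introduced here for illustration.

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[0, 2],
                      [1, 2]])

gathered = tensor_0.gather(1, index)

# Row i of the index selects columns from row i of tensor_0, so pairing each
# index entry with its own row number reproduces the same lookup.
rows = torch.arange(index.size(0)).unsqueeze(1)  # shape (2, 1): [[0], [1]]
indexed = tensor_0[rows, index]

print(torch.equal(gathered, indexed))  # True
```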

3. Summary

From the above example, the main points of using torch.gather() can be summarized:

  1. The shape of the output equals the shape of the input index
  2. Each value in index replaces only the coordinate along the specified dim; the coordinates along all other dimensions stay those of the value's own position in index
  3. The final output is the value of the original tensor at the coordinates obtained after this replacement
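
These points can be checked directly on the tensor from section 2. The following is a small sketch for the dim=1 case only, reusing the index from 2.4; the explicit loops are there for readability, not speed.

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[0, 2],
                      [1, 2]])
out = tensor_0.gather(1, index)

# Point 1: the output has the shape of the index, not of the input.
assert out.shape == index.shape

# Points 2 and 3: with dim=1 only the column coordinate is replaced,
# so out[i][j] is read from row i, column index[i][j] of the input.
for i in range(index.size(0)):
    for j in range(index.size(1)):
        assert out[i, j] == tensor_0[i, index[i, j]]

print(out.shape)  # torch.Size([2, 2])
```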

4. A simple way to understand it

Understanding torch.gather:
index = [ [x1, x2, x3],
          [y1, y2, y3],
          [z1, z2, z3] ]

If dim=0, the (row, column) positions read from the input are filled in as:
[ [(x1,0), (x2,1), (x3,2)],
  [(y1,0), (y2,1), (y3,2)],
  [(z1,0), (z2,1), (z3,2)] ]

If dim=1:
[ [(0,x1), (0,x2), (0,x3)],
  [(1,y1), (1,y2), (1,y3)],
  [(2,z1), (2,z2), (2,z3)] ]
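
The coordinate patterns above can be written out in plain Python on nested lists. This is a sketch of the two-dimensional case only; gather_coords and gather_2d are hypothetical helper names introduced here and are not part of PyTorch.

```python
def gather_coords(index, dim):
    """Return, for each output position, the (row, col) read from the input."""
    coords = []
    for i, row in enumerate(index):
        if dim == 0:
            # The index value replaces the row; the column is the position j.
            coords.append([(value, j) for j, value in enumerate(row)])
        elif dim == 1:
            # The index value replaces the column; the row is the position i.
            coords.append([(i, value) for value in row])
        else:
            raise ValueError("this sketch only handles dim=0 or dim=1")
    return coords


def gather_2d(data, index, dim):
    """Look up every coordinate pair in a 2D nested list."""
    return [[data[r][c] for r, c in row] for row in gather_coords(index, dim)]


data = [[3, 4, 5], [6, 7, 8], [9, 10, 11]]
print(gather_coords([[2, 1, 0]], dim=0))          # [[(2, 0), (1, 1), (0, 2)]]
print(gather_2d(data, [[2, 1, 0]], dim=0))        # [[9, 7, 5]]
print(gather_2d(data, [[0, 2], [1, 2]], dim=1))   # [[3, 5], [7, 8]]
```

The printed results match the outputs of examples 2.1 and 2.4.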

5. Reference link

https://zhuanlan.zhihu.com/p/352877584

Origin blog.csdn.net/flyingluohaipeng/article/details/128060091#comments_27151922