Diagram of the torch.gather function in PyTorch

1 Background

Last year I worked out how torch.gather() is used, but when I ran into it again this year I had already forgotten. This post is a quick write-up of my understanding so that I can get back up to speed the next time I forget.

Official documentation:

The definition of torch.gather() in the official documentation is very concise:

Definition: gather values from the original tensor along the specified dim at the specified index.

From this core definition, the basic idea behind gather() is simply fetching values from the full data by index, just like indexing into the list below:

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

The example above takes a single value or a contiguous slice. With the batch tensors commonly used in deep learning, however, we often need to pick multiple values in an arbitrary order. That is where gather() comes in: it extracts the data at a specified, possibly out-of-order, index from a batch tensor. Its purpose can be stated as follows:

Purpose: conveniently fetch the data at specified indices from a batch tensor; the selection is fully customizable and may be out of order.
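To make the rule precise, the documentation defines gather() element-wise; for a two-dimensional tensor the output is built as out[i][j] = input[index[i][j]][j] when dim=0, and out[i][j] = input[i][index[i][j]] when dim=1. A tiny illustrative example (made up here, not from the original article):

import torch

# Element-wise rule (2-D case):
#   dim=0: out[i][j] = input[index[i][j]][j]
#   dim=1: out[i][j] = input[i][index[i][j]]
# The index tensor must have dtype torch.long, and the output
# has exactly the same shape as the index tensor.
inp = torch.tensor([[10, 20, 30],
                    [40, 50, 60]])
idx = torch.tensor([[2, 0],
                    [1, 1]])
print(inp.gather(1, idx))
# tensor([[30, 10],
#         [50, 50]])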

2 Hands-on examples

Let's take a 3×3 two-dimensional matrix and experiment with it.

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

Output:

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

2.1 Input a row-vector index, replacing the row index (dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

Output:

tensor([[9, 7, 5]])

The process: with dim=0, each index value replaces the row index while the column index stays put, so the output is [tensor_0[2][0], tensor_0[1][1], tensor_0[0][2]] = [9, 7, 5].
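As a sanity check (not in the original post), the same result can be reproduced with an explicit loop over the rule out[i][j] = tensor_0[index[i][j]][j]:

# Rebuild gather(0, index) by hand: for dim=0 the index value replaces
# the row index, while the column index stays the same.
manual = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        manual[i][j] = tensor_0[index[i][j]][j]
print(manual)                          # tensor([[9, 7, 5]])
print(torch.equal(manual, tensor_1))   # True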

2.2 Input a row-vector index, replacing the column index (dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

Output:

tensor([[5, 4, 3]])

The process: with dim=1, each index value replaces the column index while the row index stays put. Since the index has only one row, only row 0 of tensor_0 is read: [tensor_0[0][2], tensor_0[0][1], tensor_0[0][0]] = [5, 4, 3].

2.3 Input a column-vector index, replacing the column index (dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

Output:

tensor([[5],
        [7],
        [9]])

The process: the index is now a column vector of shape (3, 1), so each row i picks the single element tensor_0[i][index[i][0]], giving [tensor_0[0][2], tensor_0[1][1], tensor_0[2][0]] = [5, 7, 9] as a column.
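This column-vector form is the classic "pick one element per row" pattern. As a comparison (an illustrative equivalence, not part of the original article), the same values can be obtained with advanced indexing:

# Equivalent advanced indexing: take element index[i][0] from row i.
rows = torch.arange(tensor_0.size(0))             # tensor([0, 1, 2])
same = tensor_0[rows, index.squeeze(1)]           # tensor([5, 7, 9])
print(torch.equal(same.unsqueeze(1), tensor_1))   # True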

2.4 Input a two-dimensional matrix index, replacing the column index (dim=1)

index = torch.tensor([[0, 2], 
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

Output:

tensor([[3, 5],
        [7, 8]])

The process is the same as above: each index value replaces the column index, so row 0 yields [tensor_0[0][0], tensor_0[0][2]] = [3, 5] and row 1 yields [tensor_0[1][1], tensor_0[1][2]] = [7, 8]. Pay attention to how the row number changes during the calculation.
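For completeness (this variant is not in the original article), the same two-dimensional index can also be used along dim=0; each index value then replaces the row index instead:

tensor_2 = tensor_0.gather(0, index)   # out[i][j] = tensor_0[index[i][j]][j]
print(tensor_2)
# tensor([[ 3, 10],
#         [ 6, 10]])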

3 Use in reinforcement learning (DQN)

In the code of the DQN tutorial on the official PyTorch website, Q(S_t, a) is obtained like this:

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

Here Q(S_t), i.e. policy_net(state_batch), is a two-dimensional tensor of shape (128, 2): one row per state in the batch and one column per action (there are 2 actions).

The action_batch tensor holds the action that was taken for each state in the batch. At this point, gather() makes it easy to pick out, for every row of Q(S_t), the Q-value of the action actually taken, i.e. the Q(S_t, a) of the batch states corresponding to the batch actions.
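To see the shapes concretely, here is a minimal self-contained sketch with a dummy network and dummy data; the shapes (4 state features, 2 actions, batch size 128) are assumptions matching the CartPole setting of the tutorial, not code taken from it:

import torch
import torch.nn as nn

batch_size, n_observations, n_actions = 128, 4, 2
policy_net = nn.Linear(n_observations, n_actions)         # stand-in for the real Q-network

state_batch = torch.randn(batch_size, n_observations)     # dummy batch of states, shape (128, 4)
action_batch = torch.randint(n_actions, (batch_size, 1))  # action taken per state, shape (128, 1)

q_values = policy_net(state_batch)                        # Q(s_t) for every action, shape (128, 2)
state_action_values = q_values.gather(1, action_batch)    # Q(s_t, a) of the taken action, shape (128, 1)
print(state_action_values.shape)                          # torch.Size([128, 1])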

4 Summary

From the typical cases above, we can summarize the key points of using torch.gather():

  • The shape of the index tensor equals the shape of the output.
  • The values in index only replace the index along the specified dim; the indices along the other dims stay unchanged.
  • The output consists of the values of the original tensor at those replaced positions.

This article is reprinted from Zhihu, the original URL: Diagram of the torch.gather function in PyTorch

Origin blog.csdn.net/weixin_46707326/article/details/120424556