1 Background
Last year I understood how torch.gather() works, but this year I came across it again and had forgotten it. This post organizes my understanding so that I can get back up to speed quickly the next time I forget.
Official documentation:
![](https://pic4.zhimg.com/v2-f8a19c1c3d3e4da167f9517615bf6857_b.jpg)
The definition of torch.gather() in the official documentation is very concise:
Definition: take values from the original tensor at the specified indices along the specified dim.
From this core definition, the basic idea behind gather() is easy to see: it simply fetches values from the full data by index, much like taking values by index from the list below
lst = [1, 2, 3, 4, 5]
value = lst[2] # value = 3
value = lst[2:4] # value = [3, 4]
The example above takes either a single value or a contiguous slice. For the batched tensor data common in deep learning, however, we often need to select multiple values in arbitrary order. Here gather() is a very good tool: it extracts the data at specified, possibly out-of-order indices from a batched tensor. Its purpose can be stated as follows:
Purpose: conveniently fetch the values at specified indices from a batched tensor; the selection is highly customizable and the indices may be out of order.
2 Hands-on examples
Let's use a 3x3 two-dimensional matrix for the experiments
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
Output:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
2.1 Row-vector index, replacing the row index (dim=0)
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
Output:
tensor([[9, 7, 5]])
The process is shown in the figure
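The result can be reproduced by hand. For dim=0, gather fills out[i][j] = tensor_0[index[i][j]][j]: the index value replaces the row index, while the column index j is kept. A minimal sketch that verifies this rule:

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])

# dim=0: the index value replaces the ROW index, the column index j is kept,
# so out[i][j] = tensor_0[index[i][j]][j]
out = tensor_0.gather(0, index)
print(out)  # tensor([[9, 7, 5]])

# Verify the rule element by element:
for j in range(3):
    assert out[0, j] == tensor_0[index[0, j], j]
```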
2.2 Row-vector index, replacing the column index (dim=1)
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
Output:
tensor([[5, 4, 3]])
The process is shown in the figure
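With dim=1 the roles are swapped: out[i][j] = tensor_0[i][index[i][j]], i.e. the index value replaces the column index while the row index i is kept. A small check of this rule:

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])

# dim=1: the index value replaces the COLUMN index, the row index i is kept,
# so out[i][j] = tensor_0[i][index[i][j]]
out = tensor_0.gather(1, index)
print(out)  # tensor([[5, 4, 3]])

# Verify the rule element by element:
for j in range(3):
    assert out[0, j] == tensor_0[0, index[0, j]]
```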
2.3 Column-vector index, replacing the column index (dim=1)
index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
Output:
tensor([[5],
[7],
[9]])
The process is shown in the figure
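This column-vector pattern is the common "pick one element per row" case. As a comparison, the same selection can also be written with plain integer-array indexing; the sketch below shows both side by side:

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]]).t()   # column vector, shape (3, 1)

# A (3, 1) index with dim=1 picks one element per row:
# out[i][0] = tensor_0[i][index[i][0]]
out = tensor_0.gather(1, index)
print(out)  # tensor([[5], [7], [9]])

# The same selection with integer-array indexing, for comparison:
rows = torch.arange(3)
print(tensor_0[rows, index.squeeze(1)])  # tensor([5, 7, 9])
```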
2.4 Two-dimensional index matrix, replacing the column index (dim=1)
index = torch.tensor([[0, 2],
[1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
Output:
tensor([[3, 5],
[7, 8]])
The process is the same as above; note how the row number changes during the calculation.
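The row behavior is worth spelling out: because the index has only two rows, the output also has only two rows, and row i of the output reads exclusively from row i of the original tensor. A short sketch making that explicit:

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[0, 2],
                      [1, 2]])

# The index has 2 rows, so the output has 2 rows as well; row i of the
# output reads from row i of tensor_0 (row 2 of tensor_0 is never read):
# out[i][j] = tensor_0[i][index[i][j]]
out = tensor_0.gather(1, index)
print(out)
# tensor([[3, 5],
#         [7, 8]])
```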
3 Use in reinforcement learning (DQN)
In the code of the DQN tutorial on the official PyTorch website, Q(S_t, a) is obtained like this:
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)
Here Q(S_t), i.e. policy_net(state_batch), is a two-dimensional table with shape (128, 2), where 2 is the number of actions.
action_batch holds the action chosen for each state in the batch, and what we want is Q(S_t, a) for exactly those actions.
The gather() function makes it easy to pick out, for each state in the batch, the Q(S_t, a) corresponding to its chosen action.
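The pattern can be sketched end to end with stand-in components. The shapes below (128 states, 4 state features, 2 actions) follow the tutorial's description; the nn.Linear is just an assumed placeholder for the real Q-network:

```python
import torch
import torch.nn as nn

# Assumed stand-ins: 128 states with 4 features each, 2 possible actions.
policy_net = nn.Linear(4, 2)                   # placeholder for the real Q-network
state_batch = torch.randn(128, 4)
action_batch = torch.randint(0, 2, (128, 1))   # one chosen action per state

q_values = policy_net(state_batch)             # shape (128, 2): Q(s_t, a) for every a
# gather along dim=1 picks, per row, the Q-value of the action actually taken:
state_action_values = q_values.gather(1, action_batch)
print(state_action_values.shape)               # torch.Size([128, 1])
```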
4 Summary
From the typical cases above, we can summarize the key points of using torch.gather():
- The shape of the index tensor equals the shape of the output.
- The values in the index tensor replace only the coordinate along the specified dim.
- The output consists of the values read from the original tensor at the replaced indices.
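The first key point can be checked mechanically: whatever shape the index tensor has, the output shape matches it.

```python
import torch

# Output shape always equals index shape, for every index used above:
tensor_0 = torch.arange(3, 12).view(3, 3)
for index in (torch.tensor([[2, 1, 0]]),        # row vector
              torch.tensor([[2, 1, 0]]).t(),    # column vector
              torch.tensor([[0, 2], [1, 2]])):  # 2-D matrix
    assert tensor_0.gather(1, index).shape == index.shape
```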
This article is reprinted from Zhihu; original post: Diagram of the torch.gather function in PyTorch.