gather() in PyTorch

Function introduction

Purpose

Gathers values from the input tensor along the specified dimension dim, at the positions given by index.

Parameters

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • input  ( Tensor ) – the source tensor

  • dim  ( int ) – the axis along which to index

  • index  ( LongTensor ) – the indices of the elements to gather

  • sparse_grad  ( bool,  optional ) – if True, the gradient with respect to input will be a sparse tensor

  • out  ( Tensor,  optional ) – the output tensor

Notes

input and index must have the same number of dimensions. In addition, index.size(d) <= input.size(d) is required for every dimension d != dim. input and index do not broadcast against each other. The output has the same shape as index.
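To make these rules concrete, here is a minimal pure-Python sketch of gather for the 2D case. It assumes rectangular nested lists stand in for tensors (so it runs without PyTorch), and the helper name gather_2d is hypothetical: it mirrors the documented indexing rule and shape constraints, but it is not PyTorch's implementation and ignores sparse_grad and out.

```python
# Hypothetical reference helper: mimics torch.gather for 2-D inputs only,
# using rectangular nested lists instead of tensors.

def gather_2d(inp, dim, index):
    """Return out with out[i][j] = inp[index[i][j]][j] if dim == 0,
    or out[i][j] = inp[i][index[i][j]] if dim == 1."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    in_rows, in_cols = len(inp), len(inp[0])
    idx_rows, idx_cols = len(index), len(index[0])
    # Along every dimension d != dim, index may not be larger than input.
    if dim == 0 and idx_cols > in_cols:
        raise ValueError("index.size(1) must be <= input.size(1)")
    if dim == 1 and idx_rows > in_rows:
        raise ValueError("index.size(0) must be <= input.size(0)")
    # Index values must be valid positions along dim (no negative wrap-around).
    limit = in_rows if dim == 0 else in_cols
    out = []
    for i in range(idx_rows):
        row = []
        for j in range(idx_cols):
            k = index[i][j]
            if not 0 <= k < limit:
                raise IndexError(f"index {k} is out of bounds for size {limit}")
            row.append(inp[k][j] if dim == 0 else inp[i][k])
        out.append(row)
    return out  # always the same shape as index


x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
idx = [[0, 1, 2, 3], [3, 2, 1, 0]]
print(gather_2d(x, 0, idx))  # [[1, 6, 11, 16], [13, 10, 7, 4]]
print(gather_2d(x, 1, idx))  # [[1, 2, 3, 4], [8, 7, 6, 5]]
```

The output shape comes only from index, which is why a (2, 4) index produces a (2, 4) result even though the input is (4, 4).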

2D-Tensor example

dim=0

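  • First create an input tensor holding the values 1 to 16 and reshape it to 4×4
import torch

# torch.range is deprecated; arange excludes its end value, hence 17
x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)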
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
"""
  • Then create the index

For example, take index = [[0, 1, 2, 3], [3, 2, 1, 0]].

Start with the first row of the index, [0, 1, 2, 3]. Because dim=0, each value selects a row of the input, while the column is the column that the value occupies in the index. The values 0, 1, 2, 3 sit in columns 0, 1, 2, 3 of the index, so the gathered elements are input[0][0], input[1][1], input[2][2], input[3][3], namely [1., 6., 11., 16.].

Next, the second row of the index, [3, 2, 1, 0], selects rows 3, 2, 1, 0 of the input. These values again sit in columns 0, 1, 2, 3 of the index, so the gathered elements are input[3][0], input[2][1], input[1][2], input[0][3], namely [13., 10., 7., 4.].
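The walkthrough above can be traced mechanically. This small sketch uses plain Python lists in place of the tensors (an assumption so it runs without PyTorch) and lists, for each output position, the (row, column) of the input element it comes from when dim=0.

```python
# For dim=0: out[i][j] = input[index[i][j]][j]
# The index value picks the row; the column is the index's own column j.
x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
index = [[0, 1, 2, 3], [3, 2, 1, 0]]

# Source coordinates in x for every output position
coords = [[(index[i][j], j) for j in range(len(index[i]))] for i in range(len(index))]
# Look up those coordinates to get the gathered values
values = [[x[r][c] for (r, c) in row] for row in coords]

print(coords)  # [[(0, 0), (1, 1), (2, 2), (3, 3)], [(3, 0), (2, 1), (1, 2), (0, 3)]]
print(values)  # [[1, 6, 11, 16], [13, 10, 7, 4]]
```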

index = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])  # int64 (long) by default
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
  • Gather along dim=0; the printed result matches the values worked out above
y = torch.gather(x, dim=0, index=index)
"""
tensor([[ 1.,  6., 11., 16.],
        [13., 10.,  7.,  4.]])
"""

dim=1

  • Create the input tensor
import torch

# torch.range is deprecated; arange excludes its end value, hence 17
x = torch.arange(1, 17, dtype=torch.float32).view(4, 4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
"""
  • Create the index

For example, again take index = [[0, 1, 2, 3], [3, 2, 1, 0]].

Start with the first row of the index, [0, 1, 2, 3]. Because dim=1, each value selects a column of the input, while the row is the row that the value occupies in the index. All four values sit in row 0 of the index, so the gathered elements are input[0][0], input[0][1], input[0][2], input[0][3], namely [1., 2., 3., 4.].

Next, [3, 2, 1, 0] selects columns 3, 2, 1, 0 of the input. These values sit in row 1 of the index, so the gathered elements are input[1][3], input[1][2], input[1][1], input[1][0], namely [8., 7., 6., 5.].
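The same trace for dim=1, again as a sketch with plain Python lists standing in for the tensors: this time the index value replaces the column coordinate and the row is kept from the index's own row.

```python
# For dim=1: out[i][j] = input[i][index[i][j]]
# The index value picks the column; the row is the index's own row i.
x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
index = [[0, 1, 2, 3], [3, 2, 1, 0]]

# Source coordinates in x for every output position
coords = [[(i, index[i][j]) for j in range(len(index[i]))] for i in range(len(index))]
# Look up those coordinates to get the gathered values
values = [[x[r][c] for (r, c) in row] for row in coords]

print(coords)  # [[(0, 0), (0, 1), (0, 2), (0, 3)], [(1, 3), (1, 2), (1, 1), (1, 0)]]
print(values)  # [[1, 2, 3, 4], [8, 7, 6, 5]]
```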

index = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])  # int64 (long) by default
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
  • Gather along dim=1; the printed result matches the values worked out above
y = torch.gather(x, dim=1, index=index)
"""
tensor([[1., 2., 3., 4.],
        [8., 7., 6., 5.]])
"""

Summary

When gathering from a 2D tensor, each value in index replaces one coordinate of the output position. With dim=0 the index value chooses the row, and the column is taken from the value's position in index. With dim=1 the index value chooses the column, and the row is taken from the value's position in index.
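Written as a formula, the two cases read as follows; this is the two-dimensional form of the indexing rule that the PyTorch documentation gives for torch.gather.

```latex
\text{out}[i][j] =
\begin{cases}
\text{input}\big[\text{index}[i][j]\big][j], & \text{dim} = 0,\\[4pt]
\text{input}[i]\big[\text{index}[i][j]\big], & \text{dim} = 1.
\end{cases}
```

For tensors with more dimensions the pattern is the same: only the coordinate at position dim is replaced by the index value, and all other coordinates are copied from the output position.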


Origin blog.csdn.net/qq_38964360/article/details/131550919