我对pytorch中gather函数的一点理解

本文首发于公众号【拇指笔记】

官方文档的解释

torch.gather(input,dim,index,out=None) → Tensor

torch.gather(input, dim, index, out=None) → Tensor

    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters: 

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

举个例子

import torch

a = torch.Tensor([[1,2],
                 [3,4]])

b = torch.gather(a,1,torch.LongTensor([[0,0],[1,0]]))
#1. 取各个元素行号:[(0,y)(0,y)][(1,y)(1,y)]
#2. 取各个元素值做行号:[(0,0)(0,0)][(1,1)(1,0)]
#3. 根据得到的索引在输入中取值
#[1,1],[4,3]

c = torch.gather(a,0,torch.LongTensor([[0,0],[1,0]]))
#1. 取各个元素列号:[(x,0)(x,1)][(x,0)(x,1)]
#2. 取各个元素值做行号:[(0,0)(0,1)][(1,0)(0,1)]
#3. 根据得到的索引在输入中取值
#[1,2],[3,2]

原理解释

假设输入与上同;index=B;输出为C
B中每个元素分别为b(0,0)=0,b(0,1)=0
b(1,0)=1,b(1,1)=0

如果dim=0(列)
则取B中元素的列号,如:b(0,1)的1
b(0,1)=0,所以C中的c(0,1)=输入的(0,1)处元素2

如果dim=1(行)
则取B中元素的列号,如:b(0,1)的0
b(0,1)=0,所以C中的c(0,1)=输入的(0,0)处元素1

总结如下:
输出 元素 在 输入张量 中的位置为:
输出元素位置取决与同位置的index元素
dim=1时,取同位置的index元素的行号做行号,该位置处index元素做列号
dim=0时,取同位置的index元素的列号做列号,该位置处index元素做行号。

最后根据得到的索引在输入中取值

index类型必须为LongTensor
gather最终的输出变量与index同形。

欢迎各位关注我的公众号,每天更新我的学习笔记,希望大家共同学习,共同进步!
在这里插入图片描述

发布了31 篇原创文章 · 获赞 72 · 访问量 4953

猜你喜欢

转载自blog.csdn.net/weixin_44610644/article/details/104607919