Tensor indexing in PyTorch

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/zjucor/article/details/88806213

I ran into a case where one tensor is used to index another tensor. PyTorch has two fairly similar operations for this:

torch.index_select and torch.gather

torch.index_select(input, dim, index, out=None) → Tensor

Returns a new tensor which indexes the input tensor along dimension dim using the entries in index, which is a LongTensor.

The returned tensor has the same number of dimensions as the original tensor (input). The dim-th dimension has the same size as the length of index; other dimensions have the same size as in the original tensor. (In other words: along dim, the output's size equals the length of index; every other dimension is unchanged.)

NOTE

The returned tensor does not use the same storage as the original tensor. If out has a different shape than expected, we silently change it to the correct shape, reallocating the underlying storage if necessary.
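A quick check of the note above, as a minimal sketch: writing into the result of torch.index_select leaves the input untouched, whereas a basic slice such as x[0:2] is a view that shares storage with x. The tensor values and variable names are illustrative only.

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
indices = torch.tensor([0, 2])

# index_select copies the selected rows into new storage
selected = torch.index_select(x, 0, indices)
selected[0, 0] = 100.0
print(x[0, 0].item())  # 0.0: the input is unchanged

# a basic slice is a view: writes through it reach the original tensor
view = x[0:2]
view[0, 0] = -1.0
print(x[0, 0].item())  # -1.0
```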

Parameters:
  • input (Tensor) – the input tensor
  • dim (int) – the dimension in which we index
  • index (LongTensor) – the 1-D tensor containing the indices to index
  • out (Tensor, optional) – the output tensor

Example:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])
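With a 1-D index, index_select along dim 0 or dim 1 gives the same result as ordinary integer-array indexing (x[indices] and x[:, indices]). The sketch below checks this equivalence and the shape rule from the description; the random input is only for illustration.

```python
import torch

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])

rows = torch.index_select(x, 0, indices)  # rows 0 and 2
cols = torch.index_select(x, 1, indices)  # columns 0 and 2

# the size along `dim` becomes len(indices); the other dimension is unchanged
print(rows.shape)  # torch.Size([2, 4])
print(cols.shape)  # torch.Size([3, 2])

# equivalent integer-array indexing expressions
assert torch.equal(rows, x[indices])
assert torch.equal(cols, x[:, indices])
```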

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
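The three formulas above can be turned directly into nested loops. The sketch below does that for 3-D tensors and compares the result with torch.gather for each dim; gather_reference is a hypothetical helper written only for this check and is far slower than the built-in.

```python
import torch

def gather_reference(inp, dim, index):
    """Naive loop version of torch.gather for 3-D tensors (hypothetical helper)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    I, J, K = index.shape
    for i in range(I):
        for j in range(J):
            for k in range(K):
                idx = index[i, j, k].item()
                if dim == 0:
                    out[i, j, k] = inp[idx, j, k]
                elif dim == 1:
                    out[i, j, k] = inp[i, idx, k]
                else:  # dim == 2
                    out[i, j, k] = inp[i, j, idx]
    return out

torch.manual_seed(0)
inp = torch.randn(2, 3, 4)
for dim in range(3):
    # random indices that are valid along `dim`, same shape as inp
    index = torch.randint(0, inp.size(dim), inp.shape)
    assert torch.equal(gather_reference(inp, dim, index), torch.gather(inp, dim, index))
print("all dims match")
```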

If input is an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, x_i, x_{i+1}, \ldots, x_{n-1})$ and dim = i, then index must be an n-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, y, x_{i+1}, \ldots, x_{n-1})$ where $y \geq 1$, and out will have the same size as index. (In other words: the output has the same shape as index, and index may differ from input in only one dimension, namely dim.)
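To make the shape rule concrete, here is a small sketch with a (2, 3, 4) input and dim = 1. The index keeps sizes 2 and 4 but has y = 5 along dim 1, which is larger than x_1 = 3; that is allowed as long as every index value is in range. The tensors are illustrative only.

```python
import torch

inp = torch.arange(24).reshape(2, 3, 4)  # size (x0, x1, x2) = (2, 3, 4)

# dim = 1: index has size y = 5 along dim 1, the other sizes match inp
index = torch.tensor([0, 2, 2, 1, 0]).view(1, 5, 1).expand(2, 5, 4)
out = torch.gather(inp, 1, index)

print(out.shape)  # torch.Size([2, 5, 4]), the same as index.shape
# each out[i, j, k] is inp[i, index[i, j, k], k]
print(out[0, :, 0])  # tensor([0, 8, 8, 4, 0])
```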

Parameters:
  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
  • out (Tensor, optional) – the destination tensor
  • sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.

Example:

>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])
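Tracing the example by hand: since dim == 1, out[i][j] = t[i][index[i][j]], so row 0 reads t[0][0] twice and row 1 reads t[1][1] then t[1][0]. The sketch below rebuilds the result with a plain list comprehension and compares it with torch.gather.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
out = torch.gather(t, 1, index)

# dim == 1, so out[i][j] = t[i][index[i][j]]
manual = torch.tensor([[t[i, index[i, j]].item() for j in range(2)] for i in range(2)])
assert torch.equal(out, manual)
print(out.tolist())  # [[1, 1], [4, 3]]
```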

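Both operations above index with a single tensor: the result changes along one dimension while the other dimensions keep their sizes.

PyTorch also allows several index tensors to be used together (advanced indexing), for example: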

roi_cls_loc1[t.arange(0, n_sample1).long().cuda(), at.totensor(gt_roi_label1).long()]
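When two index tensors of the same length are passed, they are paired element-wise: the result at position k is taken from row rows[k] and column cols[k]. The sketch below assumes roi_cls_loc1 has shape (n_sample, n_class, 4), i.e. 4 box offsets per class per sample, and that the labels select one class per sample; loc, labels, n and n_class are hypothetical stand-ins. It runs on CPU, so the .cuda() call is omitted, and it shows the same selection done with gather and with a loop.

```python
import torch

# Hypothetical stand-ins for roi_cls_loc1 / gt_roi_label1 (shapes are assumptions):
# 4 box offsets per class for each of n samples, plus one class label per sample.
n, n_class = 5, 3
loc = torch.randn(n, n_class, 4)
labels = torch.tensor([2, 0, 1, 1, 0])

# Paired index tensors: picked[k] = loc[k, labels[k]], one class row per sample
picked = loc[torch.arange(n), labels]
print(picked.shape)  # torch.Size([5, 4])

# Same result with gather: index needs as many dims as loc, size 1 along dim 1
idx = labels.view(n, 1, 1).expand(n, 1, 4)
via_gather = torch.gather(loc, 1, idx).squeeze(1)
assert torch.equal(picked, via_gather)

# Same result with a plain loop
via_loop = torch.stack([loc[k, labels[k]] for k in range(n)])
assert torch.equal(picked, via_loop)
```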

