torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

参考链接: torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
参考链接: topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor)
参考链接: PyTorch使用torch.sort()函数来筛选出前k个最大的项或者筛选出前k个最小的项

在这里插入图片描述

在这里插入图片描述

原文及翻译:

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
函数: torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) 返回的是一个具名元组, 类型是(Tensor类型, LongTensor类型)

    Returns the k largest elements of the given input tensor along 
    a given dimension.
    在给定的维度上将给定的input张量的k个最大的元素返回.

    If dim is not given, the last dimension of the input is chosen.
    如果dim参数没有给出,那么就默认选择张量input的最后一个维度.

    If largest is False then the k smallest elements are returned.
    如果参数largest 是False,那么返回最小的k个元素.

    A namedtuple of (values, indices) is returned, where the indices 
    are the indices of the elements in the original input tensor.
	返回一个具名元组(values, indices),其中indices是返回的元素在原始的
	张量input中的索引位置.
    The boolean option sorted if True, will make sure that the 
    returned k elements are themselves sorted
    布尔类型的选项参数sorted如果是True,那么就会确保返回的k个元素本身
    是有序的.

    Parameters  参数
            input (Tensor) – the input tensor
            input (Tensor类型) – 输入的张量
            k (int) – the k in “top-k”
            k (int类型) – 即“前k个”中的k值
            dim (int, optional) – the dimension to sort along
            dim (int类型, 可选) – 表示沿着那个维度上进行排序
            largest (bool, optional) – controls whether to return 
            largest or smallest elements
            largest (布尔类型, 可选) – 用于控制返回最大的元素还是最小的元素
            sorted (bool, optional) – controls whether to return the 
            elements in sorted order
			sorted (bool类型, 可选) – 用于控制返回的元素是否需要排好序
            out (tuple, optional) – the output tuple of (Tensor, 
            LongTensor) that can be optionally given to be used as 
            output buffers
            out (元组类型, 可选) – 这是输出数据的元组
            (Tensor类型, LongTensor类型),该元组可以可选地给出,用于输出数据
            地缓冲

    Example:  例子:

    >>> x = torch.arange(1., 6.)
    >>> x
    tensor([ 1.,  2.,  3.,  4.,  5.])
    >>> torch.topk(x, 3)
    torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))

代码实验展示:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000167F5247870>
>>>
>>> data = torch.randint(100,(15,))
>>> data
tensor([63, 48, 14, 47, 28,  5, 80, 68, 88, 61,  6, 84, 82, 87, 59])
>>>
>>> torch.topk(data, k=7, dim=0, largest=True, sorted=False, out=None)
torch.return_types.topk(
values=tensor([88, 87, 84, 82, 80, 68, 63]),
indices=tensor([ 8, 13, 11, 12,  6,  7,  0]))
>>>
>>> result = torch.topk(data, k=7, dim=0, largest=True, sorted=False, out=None)
>>> type(result)
<class 'torch.return_types.topk'>
>>> isinstance(1,int)
True
>>> # 返回的是具名数组的子类
>>> isinstance(result, tuple)
True
>>> result
torch.return_types.topk(
values=tensor([88, 87, 84, 82, 80, 68, 63]),
indices=tensor([ 8, 13, 11, 12,  6,  7,  0]))
>>> result[0]
tensor([88, 87, 84, 82, 80, 68, 63])
>>> result[1]
tensor([ 8, 13, 11, 12,  6,  7,  0])
>>> result.values
tensor([88, 87, 84, 82, 80, 68, 63])
>>> result.indices
tensor([ 8, 13, 11, 12,  6,  7,  0])
>>> type(result[0])
<class 'torch.Tensor'>
>>> type(result[1])
<class 'torch.Tensor'>
>>> result[0].type()
'torch.LongTensor'
>>> result[1].type()
'torch.LongTensor'
>>>
>>> a = torch.randn(15)
>>> a
tensor([ 2.0655, -2.1909, -1.4939, -0.9636,  0.9066, -0.3035,  0.6827, -0.3857,
        -0.6579, -1.2139, -1.5293, -1.8297, -1.8935,  0.4766, -0.9571])
>>> result = torch.topk(a, k=7, dim=0, largest=True, sorted=False, out=None)
>>> result
torch.return_types.topk(
values=tensor([ 2.0655,  0.9066,  0.6827,  0.4766, -0.3035, -0.3857, -0.6579]),
indices=tensor([ 0,  4,  6, 13,  5,  7,  8]))
>>> type(result[0])
<class 'torch.Tensor'>
>>> type(result[1])
<class 'torch.Tensor'>
>>> result[0].type()
'torch.FloatTensor'
>>> result[1].type()
'torch.LongTensor'
>>>
>>>
>>>
>>> torch.topk(data, k=7, dim=0, largest=True, sorted=False, out=None)
torch.return_types.topk(
values=tensor([88, 87, 84, 82, 80, 68, 63]),
indices=tensor([ 8, 13, 11, 12,  6,  7,  0]))
>>>
>>> torch.topk(data, k=7, dim=0, largest=True, sorted=True, out=None)
torch.return_types.topk(
values=tensor([88, 87, 84, 82, 80, 68, 63]),
indices=tensor([ 8, 13, 11, 12,  6,  7,  0]))
>>>
>>> torch.topk(data, k=7, dim=0, largest=False, sorted=False, out=None)
torch.return_types.topk(
values=tensor([ 5,  6, 14, 28, 47, 48, 59]),
indices=tensor([ 5, 10,  2,  4,  3,  1, 14]))
>>>
>>> torch.topk(data, k=7, dim=0, largest=False, sorted=True, out=None)
torch.return_types.topk(
values=tensor([ 5,  6, 14, 28, 47, 48, 59]),
indices=tensor([ 5, 10,  2,  4,  3,  1, 14]))
>>>
>>>
>>> data = a
>>> data
tensor([ 2.0655, -2.1909, -1.4939, -0.9636,  0.9066, -0.3035,  0.6827, -0.3857,
        -0.6579, -1.2139, -1.5293, -1.8297, -1.8935,  0.4766, -0.9571])
>>>
>>> torch.topk(data, k=7, dim=0, largest=True, sorted=False, out=None)
torch.return_types.topk(
values=tensor([ 2.0655,  0.9066,  0.6827,  0.4766, -0.3035, -0.3857, -0.6579]),
indices=tensor([ 0,  4,  6, 13,  5,  7,  8]))
>>>
>>> torch.topk(data, k=7, dim=0, largest=True, sorted=True, out=None)
torch.return_types.topk(
values=tensor([ 2.0655,  0.9066,  0.6827,  0.4766, -0.3035, -0.3857, -0.6579]),
indices=tensor([ 0,  4,  6, 13,  5,  7,  8]))
>>>
>>> torch.topk(data, k=7, dim=0, largest=False, sorted=False, out=None)
torch.return_types.topk(
values=tensor([-2.1909, -1.8935, -1.8297, -1.5293, -1.4939, -1.2139, -0.9636]),
indices=tensor([ 1, 12, 11, 10,  2,  9,  3]))
>>>
>>> torch.topk(data, k=7, dim=0, largest=False, sorted=True, out=None)
torch.return_types.topk(
values=tensor([-2.1909, -1.8935, -1.8297, -1.5293, -1.4939, -1.2139, -0.9636]),
indices=tensor([ 1, 12, 11, 10,  2,  9,  3]))
>>>
>>> data
tensor([ 2.0655, -2.1909, -1.4939, -0.9636,  0.9066, -0.3035,  0.6827, -0.3857,
        -0.6579, -1.2139, -1.5293, -1.8297, -1.8935,  0.4766, -0.9571])
>>>
>>>
>>>

猜你喜欢

转载自blog.csdn.net/m0_46653437/article/details/112914482