torch.max()函数==》返回该维度的最大值以及该维度最大值对应的索引

今天在学习TTSR的过程总遇到了一行代码,我发现max()函数竟然可以返回两个值,于是我决定重新学习一下这个函数

R_lv3_star, R_lv3_star_arg = torch.max(R_lv3, dim=1) #[N, H*W]  hi


 1、基础用法:

首先是 torch.max()的基础用法,输入一个张量,返回一个确定的最大值

torch.max(input) → Tensor

Example:

>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.6763,  0.7445, -2.2369]])
>>> torch.max(a)
tensor(0.7445)

 2、深度用法:

torch.max(inputdimkeepdim=False*out=None)

按维度dim 返回最大值,并且返回索引。

Parameters

  • input (Tensor) – the input tensor.

  • dim (int) – the dimension to reduce.

  • keepdim (bool) – whether the output tensor has dim retained or not. Default: False.

Keyword Arguments

out (tupleoptional) – the result tuple of two output tensors (max, max_indices),返回的最大值和索引各是一个tensor,分别表示该维度的最大值,以及该维度最大值的索引,一起构成元组(Tensor, LongTensor)

Example:

torch.max(a,0)返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)。返回的最大值和索引各是一个tensor,一起构成元组(Tensor, LongTensor)

a = torch.randn(4, 4)
print(a)
print(torch.max(a,0))


tensor([[ 0.7439,  2.2739, -2.7576, -0.0676],
        [-0.7755, -0.6696,  0.3009, -1.4939],
        [-0.9244,  2.7325,  1.7982,  1.2904],
        [-0.9091, -0.1857, -1.3392, -1.2928]])
torch.return_types.max(
values=tensor([0.7439, 2.7325, 1.7982, 1.2904]),
indices=tensor([0, 2, 2, 2]))

torch.max(a,1)返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)

>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))

Pytorch笔记torch.max() - 知乎

torch.max — PyTorch 1.10 documentation

猜你喜欢

转载自blog.csdn.net/weixin_43135178/article/details/123257024