torch.max()详解

torch.max()

在这里插入图片描述
pytorch文档中提到:该函数返回一个元组:(值,索引),其中值是给定维度dim中输入张量每行的最大值。索引是找到的每个最大值(argmax)的索引位置。
如果keepdim为True,则输出张量的大小与输入相同,但维度dim中的大小为1。否则dim被压缩,导致输出张量的维数比输入少1。
注:若有多个最大值,则返回第一个最大值的索引

代码演示

a = torch.randn(4, 4)
print(a)
#tensor([[-0.7670, -0.2193,  0.1777,  0.3602],
#        [ 1.0125,  0.8830, -1.1294, -1.8622],
#        [ 1.3611,  1.2073,  1.8415, -1.4175],
#        [-0.7687,  0.6015,  0.1030, -0.1119]])
a1 = torch.max(a)  # 所有元素中最大的
print(a1)
#tensor(1.8415)
a2 = torch.max(a, 0)  # 返回每一列的最大值,及其索引
print(a2)
#torch.return_types.max(
#values=tensor([1.3611, 1.2073, 1.8415, 0.3602]),
#indices=tensor([2, 2, 2, 0]))
a3 = torch.max(a, 1)  # 返回每一行的最大值,及其索引
print(a3)
#torch.return_types.max(
#values=tensor([0.3602, 1.0125, 1.8415, 0.6015]),
#indices=tensor([3, 0, 2, 1]))
a4 = torch.max(a, 1)[0]  # 只返回最大值
print(a4)
#tensor([0.3602, 1.0125, 1.8415, 0.6015])
a5 = torch.max(a, 1)[1]  # 只返回最大值索引
print(a5)
#tensor([3, 0, 2, 1])
a6 = torch.max(a, 1)[1].numpy() # 将结果转化为Numpy格式
print(a6)
#[3 0 2 1]

猜你喜欢

转载自blog.csdn.net/gary101818/article/details/129303844