torch.max函数解释

import torch

outputs = torch.Tensor([[ 0.0052, -0.8156,  0.4052, -2.0467,  0.9094],
        [ 0.5136,  0.9679, -0.4025, -1.4280, -0.2329],
        [-1.5564, -2.9252,  1.5007, -1.8669,  2.6327],
        [-1.3189, -3.0958,  1.7134, -2.4758,  2.1970],
        [ 0.3545,  1.0336, -1.0870, -0.5224, -0.2577],
        [ 0.9172,  0.5531, -0.6612, -1.8037, -0.0375],
        [-1.8091, -3.4890,  1.7762, -2.2764,  2.9902],
        [-1.1849,  1.9803, -2.5143,  2.7419, -0.5074],
        [ 0.2643,  1.1787, -0.6099, -1.1116, -0.1860],
        [-0.1398,  1.2845, -1.3439,  0.4861, -0.4152],
        [ 0.3470,  0.1520, -0.1800, -1.8029,  0.2673],
        [-0.3466,  0.7583, -1.4018,  0.7416, -0.1349],
        [ 0.0110,  1.3657, -1.5803,  0.3905, -0.4626],
        [ 0.6834,  1.2577, -0.4732, -1.6061, -0.2986],
        [-1.4266, -3.1705,  1.6706, -2.4082,  2.4462],
        [ 0.9110,  1.6789, -1.0384, -1.3190, -0.6248],
        [ 0.6784,  1.4374, -0.5409, -1.5479, -0.4078],
        [ 0.0870,  1.6230, -1.5944,  0.0156, -0.5264],
        [ 0.4745,  0.5972, -0.3581, -0.7977, -0.1117],
        [-0.9344, -2.2650,  0.7877, -1.7376,  2.0582],
        [ 0.8842,  1.8267, -0.7559, -1.6596, -0.5955],
        [-0.7785, -0.2684, -0.9535,  0.3781,  0.5765],
        [ 0.3231,  0.3184, -0.1985, -1.1899,  0.0858],
        [ 0.8716,  1.9445, -0.8713, -1.7622, -0.6823],
        [-1.0959,  1.8133, -2.4571,  2.5514, -0.4569],
        [ 0.5522,  1.1478, -0.4801, -1.0794, -0.3068],
        [ 1.1927,  1.4754, -0.7350, -1.9110, -0.3887],
        [-1.6291, -1.5367, -0.1959, -0.2711,  1.4813],
        [ 0.6408,  1.1346, -0.4077, -1.2944, -0.3021],
        [ 0.5837,  1.2832, -0.4584, -1.4342, -0.3770],
        [-0.5513,  1.9288, -2.2596,  1.6180, -0.5218],
        [ 0.7435,  0.5310, -0.5460, -1.1995, -0.1450]])
predict_y = torch.max(outputs, dim=1)#max函数返回两个值,其中一个是values,另外一个值是index。
print(predict_y)

请添加图片描述

猜你喜欢

转载自blog.csdn.net/guoguozgw/article/details/128985800