假设数据如下,每行是softmax输出得到的概率,我需要找到最大的概率返回类别,可以使用argmax函数
(1)注意使用argmax函数时,需要将数据转换为tensor类型,否则报错 argmax(): argument 'input' (position 1) must be Tensor, not numpy.ndarray (2)torch.argmax函数需要传递dim参数,dim=1就是在行上求 index = torch.argmax(data_pre, dim=1)
import numpy as np
import pandas as pd
import torch
data_pre = np.loadtxt('./pred.txt')
data_pre = torch.tensor(data_pre)
index = torch.argmax(data_pre, dim=1)
index = np.array(index)
np.savetxt('./cluster.txt', (index))