torch.argmax()函数【求最大值的索引,并让指定维度消失】

torch.argmax(input, dim=None, keepdim=False)

argmax函数:返回指定维度最大值的索引,dim指定某一维度,那么这一维度就会消失返回的所有维度会少这个dim指定的维度,根据这个返回的维度,确定对哪个维度采取argmax操作

例如输入是token_output的维度是(62,320,523):target_len:62【序列最大长度】, 320【batch-size】, 523【词表大小】

output_all_token_id = torch.argmax(token_output, -1).tolist()

这段话的意思就是在让最后一维消失(取每个批次生成概率最大的token),那么就变成(62,320)维度了,意思就是320条生成的文本

简单例子:

假如是二维矩阵:

dim=0意思就是“行”这一维度消失,只剩下列,也就是求每一列中最大值的索引

dim=1意思就是“列”这一维度消失,只剩下行,也就是求每一行中最大值的索引

import torch
a = torch.randn(2, 3)
print(a)

tensor([[-0.3018,  0.3350,  0.8318],
        [ 0.2485,  0.5349, -1.2342]])

# 求所有值中最大值的索引
print(torch.argmax(a))

# dim=0意思就是“行”这一维度消失,只剩下列,也就是求每一列中最大值的索引
print(torch.argmax(a, dim=0))

# dim=1意思就是“列”这一维度消失,只剩下行,也就是求每一行中最大值的索引
print(torch.argmax(a, dim=1))

tensor(2)
tensor([2, 1])
tensor([1, 1, 0])

torch.argmax函数说明_Egozjuer的博客-CSDN博客

猜你喜欢

转载自blog.csdn.net/weixin_43135178/article/details/131666388