在 PyTorch 中,torch.argmax
是一个函数,用于返回张量中某个维度上最大值的索引,即张量中在某个维度上具有最大值的元素所在的位置。该函数的语法如下:
torch.argmax(input, dim=None, keepdim=False)
其中:
input
:要在其中查找最大元素的张量。dim
:如果指定了此参数,则在指定的维度上查找最大元素。否则,将在整个张量上查找最大元素。keepdim
:如果将此参数设置为 True,则将输出张量的形状保持与输入张量的形状相同。
例如,如果我们有一个形状为 (3, 4) 的张量 x
,并且想在第 0 个维度上查找最大元素,则可以使用以下代码:
import torch
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
max_indices = torch.argmax(x, dim=0)
print(max_indices) # tensor([2, 2, 2, 2])
在这里,第一个维度的大小为 3,第二个维度的大小为 4。我们在第 0 个维度上使用 torch.argmax
函数来查找 x
中最大值所在的索引。函数返回的张量 max_indices
包含四个元素,分别代表第 0 个维度上最大值所在的位置(在本例中,每个元素的值都是 2,表示第 2 行是第一列、第二列、第三列和第四列中的最大值所在的位置)。
如果指定了 keepdim=True
,则输出张量的形状将与输入张量的形状相同,只是在指定的维度上将大小设置为 1:
max_indices = torch.argmax(x, dim=0, keepdim=True)
print(max_indices.size()) # torch.Size([1, 4])