torch.argmax 的使用

在 PyTorch 中,torch.argmax 是一个函数,用于返回张量中某个维度上最大值的索引,即张量中在某个维度上具有最大值的元素所在的位置。该函数的语法如下:

torch.argmax(input, dim=None, keepdim=False)

其中:

  • input:要在其中查找最大元素的张量。
  • dim:如果指定了此参数,则在指定的维度上查找最大元素。否则,将在整个张量上查找最大元素。
  • keepdim:如果将此参数设置为 True,则将输出张量的形状保持与输入张量的形状相同。

例如,如果我们有一个形状为 (3, 4) 的张量 x,并且想在第 0 个维度上查找最大元素,则可以使用以下代码:

import torch

x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])
max_indices = torch.argmax(x, dim=0)
print(max_indices)  # tensor([2, 2, 2, 2])

在这里,第一个维度的大小为 3,第二个维度的大小为 4。我们在第 0 个维度上使用 torch.argmax 函数来查找 x 中最大值所在的索引。函数返回的张量 max_indices 包含四个元素,分别代表第 0 个维度上最大值所在的位置(在本例中,每个元素的值都是 2,表示第 2 行是第一列、第二列、第三列和第四列中的最大值所在的位置)。

如果指定了 keepdim=True,则输出张量的形状将与输入张量的形状相同,只是在指定的维度上将大小设置为 1:

max_indices = torch.argmax(x, dim=0, keepdim=True)
print(max_indices.size())  # torch.Size([1, 4])

猜你喜欢

转载自blog.csdn.net/djdjdhch/article/details/130639740