【Pytorch】Tensor的缩减操作

Tensor的缩减操作

Tensor的主要运算操作通常分为四大类:

  1. Reshaping operations(重塑操作)
  2. Element-wise operations(元素操作)
  3. Reduction operations(缩减操作)
  4. Access operations(访问操作)

缩减操作

​ 一个张量的缩减操作是一个减少张量中包含的元素数量的操作,其实质就是允许我们对单个张量中的元素执行操作。常见的缩减操作主要有:

  1. sum
  2. prod
  3. mean
  4. std
  5. max
  6. argmax

下面以sum、max和argmax操作进行说明,演示的张量如下,其他操作可以类比进行理解。

t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)

sum操作

sum操作是将张量中的每一个元素进行累加的操作,现在我们通过累加前后的元素个数对比来体会缩减操作的含义。

(1)检测操作之前的元素个数

t.numel()

显示结果:

12

(2)执行求和操作

t.sum()

显示结果:

tensor(24.)

(3)比较执行求和操作前后的元素个数

t.sum().numel() < t.numel()

显示结果:

True

由于该操作减少了许多元素,所以求和操作是一个缩减操作。

那么,现在需要思考的是,缩减操作总是会变成一个带有单个元素的张量吗?

其实不是的,我们可以传递维度参数的值来减少特定的轴。

(4)沿着第一个轴来进行求和

t.sum(dim=0)

显示结果:

tensor([6., 6., 6., 6.])

其相加过程就是取第一个轴的所有元素的总和,代码演示如下:

print(t[0])
print(t[1])
print(t[2])
print(t[0]+t[1]+t[2])

显示结果:

tensor([1., 1., 1., 1.])
tensor([2., 2., 2., 2.])
tensor([3., 3., 3., 3.])
tensor([6., 6., 6., 6.])

(5)沿着第二个轴来进行求和

t.sum(dim=1)

显示结果:

tensor([ 4.,  8., 12.])

其相加过程就是取第二个轴的所有元素的总和,代码演示如下:

print(t[0].sum())
print(t[1].sum())
print(t[2].sum())

显示结果:

tensor(4.)
tensor(8.)
tensor(12.)

argmax操作

简化操作,其作用是返回张量中元素最大值的索引位置,当我们在一个张量上调用argmax方法时,这个张量被缩减为一个新的张量,它包含一个单独的索引值,指示着张量里面的最大值。
在这里插入图片描述

现假设存在以下张量:

t = torch.tensor([
    [1, 0, 0, 2],
    [0, 3, 3, 0],
    [4, 0, 0, 5]
], dtype=torch.float32)

(1)处理整个张量

print(t.max())
print(t.argmax())

显示结果:

tensor(5.)
tensor(11)

可以看到 t.argmax() 返回的结果是11,其实这个11是先将张量进行flatten操作后再取索引的结果:

t.flatten()

显示结果:

tensor([1., 0., 0., 2., 0., 3., 3., 0., 4., 0., 0., 5.])

(2)处理特定的轴

​ 指定第一个轴

print(t.max(dim=0))
print(t.argmax(dim=0))

​ 显示结果:

torch.return_types.max(
values=tensor([4., 3., 3., 5.]),
indices=tensor([2, 1, 1, 2]))
tensor([2, 1, 1, 2])

​ 指定第二个轴

print(t.max(dim=1))
print(t.argmax(dim=1))

​ 显示结果:

torch.return_types.max(
values=tensor([2., 3., 5.]),
indices=tensor([3, 1, 3]))
tensor([3, 1, 3])

猜你喜欢

转载自blog.csdn.net/zzy_NIC/article/details/119567319