pytorch之scatter() 函数

当在PyTorch中需要根据指定的索引来将值分散(scatter)到张量的特定位置时,可以使用scatter函数。这个函数在处理非连续索引的情况下非常有用。

函数写法为:

target.scatter(dim, index, src)

一 函数介绍

scatter(input, dim, index, src):将src中数据根据index中的索引按照dim的方向填进input。可以理解成放置元素或者修改元素

  • target:即目标张量,将在该张量上进行映射
  • src:即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐行进行列填充
  • index: 按照轴方向,在target张量中需要填充的位置

scatter()函数演示:

1. 创建示例目标张量(target tensor),目标张量在dim维度上不小于源张量,其他维度上一般与源张量相同 

import torch
 
# 创建一个大小为[4, 4]的零张量
target = torch.zeros(4, 4)

 2. 创建索引张量(index tensor),该张量确定在目标张量中的哪些位置进行散点操作。索引张量的形状通常与源张量相同,索引内容可以重复,未被扫描到的值不变, 重复扫描使用最后一个位置的值, 但数据类型为整数。

# 创建一个大小为[4, 4]的索引张量
index = torch.tensor([[0, 1, 2, 3],
                     [1, 2, 3, 0],
                     [2, 3, 0, 1],
                     [3, 0, 1, 2]])

3. 创建源张量(src tensor),该张量包含你要分散到目标张量中的值。源张量的形状通常与目标张量相同。 

# 创建一个大小为[4, 4]的值张量
values = torch.tensor([[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12],
                      [13, 14, 15, 16]], dtype=torch.float32)

4. 使用scatter函数进行散点操作:

# 使用scatter进行散点操作
result = target.scatter(0, index, values)
print(result)

5. 得到结果:

tensor([[ 1., 14., 11.,  8.],
        [ 5.,  2., 15., 12.],
        [ 9.,  6.,  3., 16.],
        [13., 10.,  7.,  4.]])

解释:索引张量第一行第一列索引为0,那么将对应位置的源张量的值‘1’散布到目标张量的行索引为0的对应位置,因此目标向量第一行第一列为1

索引向量第二行第三列索引为3,对应位置源张量值为7,则将4散布到目标张量的行索引为3的对应位置,因此目标向量第四行第三列为7

猜你喜欢

转载自blog.csdn.net/weixin_41147796/article/details/138307276