当在PyTorch中需要根据指定的索引来将值分散(scatter)到张量的特定位置时,可以使用scatter函数。这个函数在处理非连续索引的情况下非常有用。
函数写法为:
target.scatter(dim, index, src)
一 函数介绍
scatter(input, dim, index, src):将src中数据根据index中的索引按照dim的方向填进input。可以理解成放置元素或者修改元素
target
:即目标张量,将在该张量上进行映射
src
:即源张量,将把该张量上的元素逐个映射到目标张量上dim
:指定轴方向,定义了填充方式。对于二维张量,dim=0
表示逐列进行行填充,而dim=1
表示逐行进行列填充index
: 按照轴方向,在target
张量中需要填充的位置
scatter()函数演示:
1. 创建示例目标张量(target tensor),目标张量在dim维度上不小于源张量,其他维度上一般与源张量相同
import torch
# 创建一个大小为[4, 4]的零张量
target = torch.zeros(4, 4)
2. 创建索引张量(index tensor),该张量确定在目标张量中的哪些位置进行散点操作。索引张量的形状通常与源张量相同,索引内容可以重复,未被扫描到的值不变, 重复扫描使用最后一个位置的值, 但数据类型为整数。
# 创建一个大小为[4, 4]的索引张量
index = torch.tensor([[0, 1, 2, 3],
[1, 2, 3, 0],
[2, 3, 0, 1],
[3, 0, 1, 2]])
3. 创建源张量(src tensor),该张量包含你要分散到目标张量中的值。源张量的形状通常与目标张量相同。
# 创建一个大小为[4, 4]的值张量
values = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]], dtype=torch.float32)
4. 使用scatter函数进行散点操作:
# 使用scatter进行散点操作
result = target.scatter(0, index, values)
print(result)
5. 得到结果:
tensor([[ 1., 14., 11., 8.],
[ 5., 2., 15., 12.],
[ 9., 6., 3., 16.],
[13., 10., 7., 4.]])
解释:索引张量第一行第一列索引为0,那么将对应位置的源张量的值‘1’散布到目标张量的行索引为0的对应位置,因此目标向量第一行第一列为1
索引向量第二行第三列索引为3,对应位置源张量值为7,则将4散布到目标张量的行索引为3的对应位置,因此目标向量第四行第三列为7