Tensor.masked_scatter()
方法详解
在 PyTorch 中,masked_scatter()
是一个用于 将另一个张量的值“散布”到当前张量的掩码位置(True
) 的方法。它是一个 按掩码进行定点赋值 的操作,常用于需要根据条件更新部分张量的场景。
1. 函数原型
Tensor.masked_scatter(mask, source) → Tensor
参数 | 说明 |
---|---|
mask |
与输入张量 可广播 的布尔张量,True 代表要替换的位置 |
source |
1D 或与输入张量广播兼容的张量,从中按顺序取值填入 mask 为 True 的位置 |
返回 | 替换后的新张量(也可以原地更新) |
2. 基本功能说明
- 在
mask == True
的位置,将source
中对应数量的值填入原张量; - 多用于 选择性地更新张量部分元素;
source.numel()
必须与mask.sum()
的个数一致。
3. 基本示例
import torch
x = torch.tensor([1, 2, 3, 4, 5])
mask = torch.tensor([False, True, True, False, True])
source = torch.tensor([10, 20, 30])
# 替换 x 中 mask==True 的位置
result = x.masked_scatter(mask, source)
print(result)
输出:
tensor([ 1, 10, 20, 4, 30])
解释:
mask
中有 3 个True
,所以source
必须有 3 个元素;- 这些值依次赋值给
x
中对应位置。
4. 多维张量示例
x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
[False, True, False]])
source = torch.tensor([1, 2, 3])
result = x.masked_scatter(mask, source)
print(result)
输出:
tensor([[1., 0., 2.],
[0., 3., 0.]])
5. 与其他 mask 操作的区别
方法 | 作用 | 特点 |
---|---|---|
masked_fill() |
将 mask=True 的位置替换为同一个标量 |
只能替换为固定值 |
masked_scatter() |
将 mask=True 的位置替换为另一个张量中的元素 |
替换为不同值 |
masked_select() |
提取 mask=True 的元素,输出 1D 张量 |
用于筛选而非赋值 |
6. 注意事项
mask
必须是布尔类型。source
的元素数量必须 正好等于mask.sum()
。- 返回的是 新张量,也可以通过
in-place
操作(加_
)实现:x.masked_scatter_(mask, source)
7. 应用场景
- 对部分元素进行更新(如根据条件替换模型输出);
- 在训练中实现条件性数据替换(如选择性修正);
- 实现自定义 attention 或 token 替换逻辑。
8. 总结
特性 | 说明 |
---|---|
功能 | 将另一张量的值按掩码位置“散布”到当前张量中 |
mask | True 的位置将被赋值 |
source | 与 mask.sum() 个数相等的张量 |
适用场景 | 条件更新、数据修复、动态替换操作 |
返回 | 新张量(或用 _ 结尾进行原地操作) |
masked_scatter()
在某些对性能敏感或实现自定义数据更新逻辑的深度学习任务中非常实用。