【PyTorch】Tensor.masked_scatter() 方法:将另一个张量的值“散布”到当前张量的掩码位置(True)

Tensor.masked_scatter() 方法详解

在 PyTorch 中,masked_scatter() 是一个用于 将另一个张量的值“散布”到当前张量的掩码位置(True 的方法。它是一个 按掩码进行定点赋值 的操作,常用于需要根据条件更新部分张量的场景。


1. 函数原型

Tensor.masked_scatter(mask, source) → Tensor
参数 说明
mask 与输入张量 可广播 的布尔张量,True 代表要替换的位置
source 1D 或与输入张量广播兼容的张量,从中按顺序取值填入 mask 为 True 的位置
返回 替换后的新张量(也可以原地更新)

2. 基本功能说明

  • mask == True 的位置,将 source 中对应数量的值填入原张量;
  • 多用于 选择性地更新张量部分元素
  • source.numel() 必须与 mask.sum() 的个数一致。

3. 基本示例

import torch

x = torch.tensor([1, 2, 3, 4, 5])
mask = torch.tensor([False, True, True, False, True])
source = torch.tensor([10, 20, 30])

# 替换 x 中 mask==True 的位置
result = x.masked_scatter(mask, source)
print(result)

输出:

tensor([ 1, 10, 20,  4, 30])

解释:

  • mask 中有 3 个 True,所以 source 必须有 3 个元素;
  • 这些值依次赋值给 x 中对应位置。

4. 多维张量示例

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
source = torch.tensor([1, 2, 3])

result = x.masked_scatter(mask, source)
print(result)

输出:

tensor([[1., 0., 2.],
        [0., 3., 0.]])

5. 与其他 mask 操作的区别

方法 作用 特点
masked_fill() mask=True 的位置替换为同一个标量 只能替换为固定值
masked_scatter() mask=True 的位置替换为另一个张量中的元素 替换为不同值
masked_select() 提取 mask=True 的元素,输出 1D 张量 用于筛选而非赋值

6. 注意事项

  • mask 必须是布尔类型。
  • source 的元素数量必须 正好等于 mask.sum()
  • 返回的是 新张量,也可以通过 in-place 操作(加 _)实现:
    x.masked_scatter_(mask, source)
    

7. 应用场景

  • 对部分元素进行更新(如根据条件替换模型输出);
  • 在训练中实现条件性数据替换(如选择性修正);
  • 实现自定义 attention 或 token 替换逻辑

8. 总结

特性 说明
功能 将另一张量的值按掩码位置“散布”到当前张量中
mask True 的位置将被赋值
source mask.sum() 个数相等的张量
适用场景 条件更新、数据修复、动态替换操作
返回 新张量(或用 _ 结尾进行原地操作)

masked_scatter() 在某些对性能敏感或实现自定义数据更新逻辑的深度学习任务中非常实用。