【PyTorch】Tensor.masked_fill() 方法:根据布尔掩码(mask)将张量中部分元素替换为指定值

Tensor.masked_fill() 方法详解

在 PyTorch 中,masked_fill() 是一个用于 根据布尔掩码(mask)将张量中部分元素替换为指定值 的方法。它常用于:

  • 遮蔽(mask)掉某些无效元素(如 attention mask)
  • 对填充值进行处理
  • 实现数值屏蔽(如用 -inf 屏蔽 softmax 中的无效位置)

1. 函数原型

Tensor.masked_fill(mask, value) → Tensor
参数 说明
mask 一个与原张量 shape 可广播的 布尔类型张量True 的位置将被替换
value 替换值,数值标量(如 0, -1e9, -inf 等)

2. 基本示例

import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

mask = x > 3

# 将大于 3 的位置替换为 0
result = x.masked_fill(mask, 0)

print(result)

输出:

tensor([[1, 2, 3],
        [0, 0, 0]])

3. 示例:填充为负无穷(-inf)用于屏蔽 attention

scores = torch.tensor([[0.1, 0.3, 0.6],
                       [0.4, 0.2, 0.4]])

mask = torch.tensor([[False, True, False],
                     [False, False, True]])

masked_scores = scores.masked_fill(mask, float('-inf'))
print(masked_scores)

输出:

tensor([[ 0.1000,   -inf,  0.6000],
        [ 0.4000,  0.2000,   -inf]])

4. 广播兼容的 mask

x = torch.ones(2, 3, 4)
mask = torch.tensor([True, False, True]).view(1, 3, 1)

x_masked = x.masked_fill(mask, -1)
print(x_masked.shape)  # torch.Size([2, 3, 4])

只对中间维度为 True 的部分填充为 -1,其余不变。


5. 常见应用场景

场景 用法
注意力机制中屏蔽无效位置 scores.masked_fill(mask, -inf)
掩盖 padding 部分的输入 x.masked_fill(pad_mask, 0)
实现 logits 屏蔽 logits.masked_fill(mask, -1e9)
在 loss 中屏蔽无效标签位置 loss.masked_fill(ignore_mask, 0)

6. 与其他方法的对比

方法 功能 是否返回新张量
masked_fill() 替换为某个值
masked_select() 筛选出满足条件的元素,返回 1D 张量
masked_scatter() 将指定位置替换为另一个张量的值 是(更复杂)

7. 注意事项

  • mask 必须为 torch.bool 类型,或者是布尔表达式的结果。
  • value 是标量,不能是张量(使用 masked_scatter 可实现张量赋值)。
  • 会返回一个 新的张量,原张量不变。

8. 总结

特性 说明
功能 将满足条件的位置替换为指定值
输入 mask(布尔张量),value(标量)
返回 新的张量,原张量不变
常用场景 attention mask、padding mask、loss 屏蔽

masked_fill() 是在构建 Transformer、seq2seq、文本处理等深度学习模型中非常常用的函数,它是高效地实现“条件赋值”的利器。