Tensor.masked_fill()
方法详解
在 PyTorch 中,masked_fill()
是一个用于 根据布尔掩码(mask)将张量中部分元素替换为指定值 的方法。它常用于:
- 遮蔽(mask)掉某些无效元素(如 attention mask)
- 对填充值进行处理
- 实现数值屏蔽(如用
-inf
屏蔽 softmax 中的无效位置)
1. 函数原型
Tensor.masked_fill(mask, value) → Tensor
参数 | 说明 |
---|---|
mask |
一个与原张量 shape 可广播的 布尔类型张量,True 的位置将被替换 |
value |
替换值,数值标量(如 0 , -1e9 , -inf 等) |
2. 基本示例
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
mask = x > 3
# 将大于 3 的位置替换为 0
result = x.masked_fill(mask, 0)
print(result)
输出:
tensor([[1, 2, 3],
[0, 0, 0]])
3. 示例:填充为负无穷(-inf)用于屏蔽 attention
scores = torch.tensor([[0.1, 0.3, 0.6],
[0.4, 0.2, 0.4]])
mask = torch.tensor([[False, True, False],
[False, False, True]])
masked_scores = scores.masked_fill(mask, float('-inf'))
print(masked_scores)
输出:
tensor([[ 0.1000, -inf, 0.6000],
[ 0.4000, 0.2000, -inf]])
4. 广播兼容的 mask
x = torch.ones(2, 3, 4)
mask = torch.tensor([True, False, True]).view(1, 3, 1)
x_masked = x.masked_fill(mask, -1)
print(x_masked.shape) # torch.Size([2, 3, 4])
只对中间维度为 True
的部分填充为 -1
,其余不变。
5. 常见应用场景
场景 | 用法 |
---|---|
注意力机制中屏蔽无效位置 | scores.masked_fill(mask, -inf) |
掩盖 padding 部分的输入 | x.masked_fill(pad_mask, 0) |
实现 logits 屏蔽 | logits.masked_fill(mask, -1e9) |
在 loss 中屏蔽无效标签位置 | loss.masked_fill(ignore_mask, 0) |
6. 与其他方法的对比
方法 | 功能 | 是否返回新张量 |
---|---|---|
masked_fill() |
替换为某个值 | 是 |
masked_select() |
筛选出满足条件的元素,返回 1D 张量 | 是 |
masked_scatter() |
将指定位置替换为另一个张量的值 | 是(更复杂) |
7. 注意事项
mask
必须为torch.bool
类型,或者是布尔表达式的结果。value
是标量,不能是张量(使用masked_scatter
可实现张量赋值)。- 会返回一个 新的张量,原张量不变。
8. 总结
特性 | 说明 |
---|---|
功能 | 将满足条件的位置替换为指定值 |
输入 | mask(布尔张量),value(标量) |
返回 | 新的张量,原张量不变 |
常用场景 | attention mask、padding mask、loss 屏蔽 |
masked_fill()
是在构建 Transformer、seq2seq、文本处理等深度学习模型中非常常用的函数,它是高效地实现“条件赋值”的利器。