Tensor.where()
方法详解(PyTorch)
在 PyTorch 中,Tensor.where()
是 torch.where()
的 张量实例方法版本,用于根据给定条件在两个张量之间逐元素选择值。
1. 方法原型
Tensor.where(condition, other) → Tensor
参数 | 说明 |
---|---|
condition |
条件张量(BoolTensor ),与调用者张量或 other 可广播 |
other |
替代张量,与调用者形状可广播 |
返回值 | 按条件选择后生成的新张量(不修改原张量) |
2. 功能说明
等价于:
torch.where(condition, self, other)
逻辑含义:
output[i] = self[i] if condition[i] else other[i]
3. 示例讲解
import torch
a = torch.tensor([10, 20, 30])
b = torch.tensor([1, 2, 3])
condition = torch.tensor([True, False, True])
# 使用 Tensor.where()
result = a.where(condition, b)
print(result)
输出:
tensor([10, 2, 30])
解释:
- 条件为
True
的位置用a
的值; - 条件为
False
的位置用b
的值。
4. 与 torch.where()
的区别
函数形式 | 写法 | 说明 |
---|---|---|
函数式 | torch.where(cond, x, y) |
推荐用于通用调用 |
面向对象 | x.where(cond, y) |
更直观(类似 NumPy) |
5. 广播示例
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
mask = torch.tensor([[True, False], [False, True]])
result = a.where(mask, b)
print(result)
输出:
tensor([[1, 20],
[30, 4]])
6. 应用场景
场景 | 示例 |
---|---|
softmax 屏蔽无效值 | logits = logits.where(mask, float('-inf')) |
训练中替代填充值 | inputs = inputs.where(valid_mask, padding_value) |
条件替换(batch 中处理) | cleaned = x.where(x >= 0, torch.tensor(0.0)) |
7. 总结
特性 | 说明 |
---|---|
功能 | 在两个张量之间逐元素选择值 |
等价写法 | a.where(cond, b) ≈ torch.where(cond, a, b) |
要求 | 所有张量可以广播(broadcasting) |
返回 | 新张量,原张量不变 |
应用 | 条件选择、屏蔽、注意力、padding 替换等 |
Tensor.where()
是 PyTorch 中非常常用的条件操作方法,特别适合在模型训练或推理中进行动态控制和替代。其与 torch.where()
本质相同,选择哪种形式取决于个人编码风格。