【PyTorch】Tensor.where() 方法:根据给定条件在两个张量之间逐元素选择值

Tensor.where() 方法详解(PyTorch)

在 PyTorch 中,Tensor.where()torch.where()张量实例方法版本,用于根据给定条件在两个张量之间逐元素选择值


1. 方法原型

Tensor.where(condition, other) → Tensor
参数 说明
condition 条件张量(BoolTensor),与调用者张量或 other 可广播
other 替代张量,与调用者形状可广播
返回值 按条件选择后生成的新张量(不修改原张量)

2. 功能说明

等价于:

torch.where(condition, self, other)

逻辑含义:

output[i] = self[i] if condition[i] else other[i]

3. 示例讲解

import torch

a = torch.tensor([10, 20, 30])
b = torch.tensor([1, 2, 3])
condition = torch.tensor([True, False, True])

# 使用 Tensor.where()
result = a.where(condition, b)
print(result)

输出:

tensor([10, 2, 30])

解释:

  • 条件为 True 的位置用 a 的值;
  • 条件为 False 的位置用 b 的值。

4. 与 torch.where() 的区别

函数形式 写法 说明
函数式 torch.where(cond, x, y) 推荐用于通用调用
面向对象 x.where(cond, y) 更直观(类似 NumPy)

5. 广播示例

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[10, 20], [30, 40]])
mask = torch.tensor([[True, False], [False, True]])

result = a.where(mask, b)
print(result)

输出:

tensor([[1, 20],
        [30, 4]])

6. 应用场景

场景 示例
softmax 屏蔽无效值 logits = logits.where(mask, float('-inf'))
训练中替代填充值 inputs = inputs.where(valid_mask, padding_value)
条件替换(batch 中处理) cleaned = x.where(x >= 0, torch.tensor(0.0))

7. 总结

特性 说明
功能 在两个张量之间逐元素选择值
等价写法 a.where(cond, b)torch.where(cond, a, b)
要求 所有张量可以广播(broadcasting)
返回 新张量,原张量不变
应用 条件选择、屏蔽、注意力、padding 替换等

Tensor.where() 是 PyTorch 中非常常用的条件操作方法,特别适合在模型训练或推理中进行动态控制和替代。其与 torch.where() 本质相同,选择哪种形式取决于个人编码风格。