当torch.where只输入一个参数时: a = torch.where(b_bool)[0]
在这个语句中a会提取出b_bool为true的角标,但会返回成一个元组。
如果b_bool是tensor类型,那返回的元组含有的唯一元素就是所有true角标的tensor
所有为了a直接得到一个tensor,我们用【0】做索引( 但前提是输入的b_bool只有一维)
若为b_bool二维,那得到的返回元组也将是二维,其中包含着的仍是true的角标,但角标的第一维在元组的第一维,角标的第二维在元组的第二维。
eg:
x = torch.tensor([[-1, 2, 0], [0, -3, 4]])
result = torch.where(x > -1) # 返回一个掩码张量
print(result)
# Output:
# (tensor([0, 0, 1, 1]), tensor([1, 2, 0, 2]))
这里和tf.where的区别:torch.where的输出形状是【2,4】,
即【num1_dim1,num2_dim1,.......】【num1_dim2,num2_dim2,.......】
而如果是tf.where输出形状将是【4,2】,
即得到四个true元素,每个元素的索引为【dim1,dim2】
另:
在 PyTorch 中,torch.where()
函数返回一个张量或元组,具体输出是由函数输出参数的数量和类型所决定的。
在下面的示例中,input_tensor
张量中大于2的元素被赋予了新的值:
import torch
input_tensor = torch.tensor([1, 2, 3, 4, 5])
condition = input_tensor > 2
output_tensor = torch.where(condition, torch.tensor(10), input_tensor)
print(output_tensor) # tensor([1, 2, 10, 10, 10])
例,在下面的示例中,我们用另一个张量 input_tensor_2
的值替换了 input_tensor
中大于2的元素:
import torch
input_tensor = torch.tensor([1, 2, 3, 4, 5])
input_tensor_2 = torch.tensor([1, 2, 3, 4, 5]) * -1
condition = input_tensor > 2
output_tensor = torch.where(condition, input_tensor_2, input_tensor)
print(output_tensor) # tensor([ 1, 2, -3, -4, -5])
--------------------------------------------------------------------------------------------------------------------------
tf.where与torch.where的几个不同点:
1.输入参数顺序不同
在tf.where
中,条件张量必须作为第一个参数,然后是True
分支和False
分支。而在torch.where
中,条件张量是最后一个参数。因此,为了在两个库中使用相同的条件和分支,我们需要在调用时调整参数的顺序。
2.自动广播的可用性不同
tf.where
和torch.where
都支持在条件张量与分支张量形状不同的情况下进行自动广播(Broadcasting)。但是在这方面,两个函数的行为是有所不同的。
在tf.where
中,自动广播使用的是NumPy风格的广播规则,也就是只有两个张量的最后一维匹配时才会发生广播。例如,如果条件张量和分支张量分别具有形状(3, 2)
和(2,)
,则在执行条件选择时,第二个张量将被广播为(3, 2)
的形状。
在torch.where
中,自动广播使用的是PyTorch风格的广播规则,该规则较为宽松。其基本思想是,如果两个张量的形状能够通过将其中一个张量插入1维来匹配,则它们可以被广播。例如,如果条件张量和分支张量分别具有形状(3, 2)
和(2,)
,则在执行条件选择时,第二个张量将被广播为(1, 2)
的形状,然后再沿着第一维复制3次以匹配条件张量的形状。
3.返回值的类型不同
在tf.where
中,返回的张量类型由True
分支和False
分支中较高的dtype决定。例如,如果True
分支和False
分支分别具有float32
和int32
类型,则返回的张量类型将为float32
。
在torch.where
中,返回的张量类型由条件张量的dtype决定。例如,如果条件张量的类型为float32
,则返回的张量类型也将为float32
。
需要注意的是,在实际使用中,这些差异可能并不总是会造成问题。如果我们按照正确的参数顺序传递相同的张量并使用正确的广播方式,那么这两个函数应该可以产生相同的输出。