>>> import torch
>>> import numpy
>>> a=torch.rand(2,2)
>>> a
tensor([[6.0998e-01, 1.1539e-04],
        [6.8827e-01, 4.0862e-01]])
>>> b=torch.rand(2,2)
>>> b
tensor([[0.5579, 0.8235],
        [0.0321, 0.7925]])
>>> a=torch.rand(2,2)
>>> a
tensor([[0.5463, 0.1074],
        [0.8725, 0.4264]])
>>> c=torch.rand(2,2)
>>> c
tensor([[0.9275, 0.0640],
        [0.4547, 0.6653]])
>>> d=torch.where(c,a,b)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected condition to have ScalarType Byte, but got ScalarType Float
>>> d=torch.where(c>0.5,a,b)
>>> d
tensor([[0.5463, 0.8235],
        [0.0321, 0.4264]])
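torch.where(condition, a, b) works like an element-wise conditional selection: where the condition is True the result takes the element from a, otherwise from b. The condition has to be a mask such as c>0.5, not a float tensor, which is why the first call above fails.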
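
The selection rule can be sketched in plain Python, without torch. The helper `where_2d` below is hypothetical (it is not part of PyTorch) and only mirrors the element-wise behaviour for nested lists; the values are copied from the session above.

```python
# Hypothetical helper, not a PyTorch function: mimics the element-wise
# selection of torch.where(mask, a, b) on 2-D nested lists.
def where_2d(mask, a, b):
    return [
        [x if m else y for m, x, y in zip(mask_row, a_row, b_row)]
        for mask_row, a_row, b_row in zip(mask, a, b)
    ]

# Values copied from the session above.
a = [[0.5463, 0.1074], [0.8725, 0.4264]]
b = [[0.5579, 0.8235], [0.0321, 0.7925]]
c = [[0.9275, 0.0640], [0.4547, 0.6653]]

# Build the boolean mask first, the counterpart of c>0.5.
mask = [[value > 0.5 for value in row] for row in c]

d = where_2d(mask, a, b)
print(d)  # [[0.5463, 0.8235], [0.0321, 0.4264]]
```

The result matches the tensor printed for d in the session: positions where c exceeds 0.5 come from a, the rest from b.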