Wenn Torch.where nur einen Parameter eingibt: a = Torch.where(b_bool)[0]
In dieser Anweisung extrahiert a den Index, bei dem b_bool wahr ist, er wird jedoch als Tupel zurückgegeben.
Wenn b_bool vom Typ Tensor ist, sind die einzigen im zurückgegebenen Tupel enthaltenen Elemente alle Tensoren mit echter Bezeichnung.
Um direkt einen Tensor für a zu erhalten, verwenden wir [0] als Index (aber nur, wenn die Eingabe b_bool nur eine Dimension hat).
Wenn b_bool zweidimensional ist, ist das zurückgegebene Tupel ebenfalls zweidimensional und enthält den Index, der immer noch wahr ist, aber die erste Dimension des Index befindet sich in der ersten Dimension des Tupels und die zweite Dimension des Index ist im Element. Die zweite Dimension der Gruppe.
z.B:
x = torch.tensor([[-1, 2, 0], [0, -3, 4]])
result = torch.where(x > -1) # 返回一个掩码张量
print(result)
# Output:
# (tensor([0, 0, 1, 1]), tensor([1, 2, 0, 2]))
Der Unterschied zwischen hier und tf.where: Die Ausgabeform von Torch.where ist [2,4],
also [num1_dim1, num2_dim1,...] [num1_dim2, num2_dim2,...]
und if ist tf.where Die Ausgabeform ist [4,2],
dh es werden vier echte Elemente erhalten, und der Index jedes Elements ist [dim1, dim2].
Andere:
In PyTorch torch.where()
gibt eine Funktion einen Tensor oder ein Tupel zurück und die spezifische Ausgabe wird durch die Anzahl und den Typ der Ausgabeparameter der Funktion bestimmt.
Im folgenden Beispiel input_tensor
werden Elementen im Tensor größer als 2 neue Werte zugewiesen:
import torch
input_tensor = torch.tensor([1, 2, 3, 4, 5])
condition = input_tensor > 2
output_tensor = torch.where(condition, torch.tensor(10), input_tensor)
print(output_tensor) # tensor([1, 2, 10, 10, 10])
Im folgenden Beispiel input_tensor_2
ersetzen wir beispielsweise input_tensor
Elemente größer als 2 in einem anderen Tensor durch den Wert eines anderen Tensors:
import torch
input_tensor = torch.tensor([1, 2, 3, 4, 5])
input_tensor_2 = torch.tensor([1, 2, 3, 4, 5]) * -1
condition = input_tensor > 2
output_tensor = torch.where(condition, input_tensor_2, input_tensor)
print(output_tensor) # tensor([ 1, 2, -3, -4, -5])
-------------------------------------------------- -------------------------------------------------- ----------------------
Mehrere Unterschiede zwischen tf.where und Torch.where:
1. Die Reihenfolge der Eingabeparameter ist unterschiedlich
In tf.where
muss der Bedingungstensor als erstes Argument angegeben werden, gefolgt von True
Verzweigungen und False
Verzweigungen. Und in torch.where
ist der Bedingungstensor der letzte Parameter. Um in beiden Bibliotheken dieselben Bedingungen und Zweige verwenden zu können, müssen wir daher die Reihenfolge der Parameter beim Aufruf anpassen.
2. Die Verfügbarkeit automatischer Übertragungen variiert
tf.where
und torch.where
beide unterstützen die automatische Übertragung (Broadcasting), wenn der Bedingungstensor und der Verzweigungstensor unterschiedliche Formen haben. In dieser Hinsicht ist das Verhalten der beiden Funktionen jedoch unterschiedlich.
Bei der tf.where
automatischen Übertragung werden Übertragungsregeln im NumPy-Stil verwendet, d. h. die Übertragung erfolgt nur, wenn die letzten Dimensionen der beiden Tensoren übereinstimmen. Wenn beispielsweise der Bedingungstensor und der Zweigtensor jeweils eine Form haben (3, 2)
, (2,)
wird bei der Durchführung der bedingten Auswahl der zweite Tensor als Form übertragen (3, 2)
.
In torch.where
verwendet die automatische Übertragung Übertragungsregeln im PyTorch-Stil, die lockerer sind. Die Grundidee besteht darin, dass zwei Tensoren übertragen werden können, wenn ihre Formen durch Interpolation eines der Tensoren in eine Dimension angepasst werden können. Wenn beispielsweise der Bedingungstensor und der Zweigtensor jeweils eine Form haben (3, 2)
, (2,)
wird bei der Bedingungsauswahl der zweite Tensor als (1, 2)
Form übertragen und dann dreimal entlang der ersten Dimension kopiert, um der Form des Bedingungstensors zu entsprechen.
3. Die Arten von Rückgabewerten sind unterschiedlich
In tf.where
wird der zurückgegebene Tensortyp durch True
den Zweig und False
den höheren dtype im Zweig bestimmt. Wenn beispielsweise True
branch und False
branch jeweils die Typen haben float32
, int32
ist der zurückgegebene Tensortyp float32
.
In torch.where
wird der zurückgegebene Tensortyp durch den dtype des Bedingungstensors bestimmt. Wenn der Bedingungstensor beispielsweise vom Typ ist float32
, wird auch der zurückgegebene Tensortyp sein float32
.
Es ist wichtig zu beachten, dass diese Unterschiede bei der tatsächlichen Verwendung möglicherweise nicht immer zu Problemen führen. Wenn wir dieselben Tensoren in der richtigen Argumentreihenfolge übergeben und das richtige Broadcasting verwenden, sollten beide Funktionen die gleiche Ausgabe erzeugen.