nn.Sequential
und nn.ModuleList()
sind PyTorch
zwei verschiedene Möglichkeiten, Submodule in einem neuronalen Netzwerkmodell zu verwalten.
nn.Sequential
ist eine Containerklasse zum Erstellen sequenzieller Modelle. Es ermöglicht das Hinzufügen einer Reihe von Submodulen in einer bestimmten Reihenfolge und deren Verkettung, um eine sequentielle Netzwerkstruktur zu bilden. nn.Sequential
Es kann die Definition des Modells und das Schreiben der Vorwärtsausbreitung vereinfachen, insbesondere für einfache Netzwerkstrukturen ohne komplexe Steuerungsprozesse. Durch das Hinzufügen nn.Sequential
von werden diese Submodule automatisch in der Reihenfolge, in der sie hinzugefügt wurden, zu einem Gesamtmodell zusammengefügt. nn.Sequential
Wenn die Methode aufgerufen wird forward
, durchlaufen die Eingabedaten jedes Untermodul in der Reihenfolge, in der sie hinzugefügt wurden, und ermöglichen so die Vorwärtspropagierung des gesamten Modells.
Das Beispiel verwendet nn.Sequential
zum Erstellen eines einfachen Modells:
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 10)
)
input_tensor = torch.randn(32, 10)
output_tensor = model(input_tensor)
In diesem Beispiel nn.Sequential
definieren . Das sequentielle Modell besteht aus drei Submodulen: einer linearen Schicht, einer ReLU
Aktivierungsfunktion und einer weiteren linearen Schicht. Wenn wir die forward
Methode , input_tensor
durchlaufen die Eingabedaten jedes Submodul in der Reihenfolge, in der sie hinzugefügt wurden, und die Ausgabedaten werden generiert output_tensor
.
Im Gegensatz dazu nn.ModuleList()
handelt es sich um einen Python
listenartigen Container zum Speichern und Verwalten einer beliebigen Anzahl von Submodulen. nn.Sequential
Anders als bei nn.ModuleList()
werden die Submodule nicht automatisch verknüpft, sondern als Liste gespeichert. Daher müssen wir bei der nn.ModuleList()
Definition eines Modells die Verbindungsbeziehung zwischen den Untermodulen selbst definieren. Dies macht nn.ModuleList()
es flexibler und eignet sich für Netzwerkstrukturen, die komplexe Steuerungsprozesse aufweisen oder benutzerdefinierte Verbindungsmethoden erfordern.
Das Beispiel verwendet nn.ModuleList()
zum Erstellen eines einfachen Modells:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.module_list = nn.ModuleList([
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 10)
])
def forward(self, x):
for module in self.module_list:
x = module(x)
return x
model = MyModel()
input_tensor = torch.randn(32, 10)
output_tensor = model(input_tensor)
In diesem Beispiel definieren wir eine benutzerdefinierte Modellklasse MyModel
, die nn.ModuleList()
zum : eine lineare Ebene, eine ReLU
Aktivierungsfunktion und eine weitere lineare Ebene. In der forward
Methode iterieren wir durch module_list
die Submodule in , x
übergeben und erhalten die endgültige Ausgabe.
Somit besteht der Unterschied zwischen undnn.Sequential
in der Fähigkeit, Submodule automatisch zu verdrahten. Verbinden Sie Submodule automatisch in der Reihenfolge, in der sie hinzugefügt wurden, geeignet für einfache sequentielle Modelle. Allerdings müssen Sie den Verbindungsmodus zwischen Submodulen manuell definieren, was für Modelle mit komplexen Steuerungsprozessen oder benutzerdefinierten Verbindungen geeignet ist.nn.ModuleList()
nn.Sequential
nn.ModuleList()
Darüber hinaus nn.Sequential
wird eine sauberere Syntax zum Definieren von Modellen bereitgestellt, da es möglich ist, Modelle direkt durch Übergabe einer Liste von Submodulen zu erstellen. nn.ModuleList()
Stattdessen müssen Sie Submodule in der Modellklasse explizit definieren und initialisieren.
nn.Sequential
und nn.ModuleList()
sind nn.Module
beide Unterklassen von , sodass sie beide als Attribute des Modells registriert und verwaltet werden können.