目录
解释
- `nn.ModuleList`:`nn.ModuleList` 是一个容器,用于存储多个 `nn.Module`。与 `nn.Sequential` 不同,`nn.ModuleList` 不会自动执行层的连接,因此需要在 `forward` 方法中手动实现前向传播。
- 特性:`nn.ModuleList` 适用于需要灵活控制网络层连接的场景,例如当网络层不是简单的线性堆叠时(如存在条件分支或跳跃连接)。
  - 它与 `nn.Sequential` 的主要区别在于:`nn.ModuleList` 只负责保存(注册)各层,不会隐式地按顺序调用它们,前向传播的实现需要手动编写。
- 前向传播:
  - 在 `forward` 方法中,依次通过 `self.layers` 中的每个层对输入数据 `x` 进行处理。每个层的输出会作为下一个层的输入。
- 在
总结
- `OrderedDict`:适用于需要对层进行命名,并明确顺序的场景。
- `nn.Sequential`:适用于网络层简单的线性堆叠。
- `nn.ModuleList`:提供更大的灵活性,适用于复杂的网络结构。
每种方式都有其适用场景,选择适当的方法可以使模型设计更加直观和灵活。
示例
在nn.Sequential中嵌套OrderedDict组织网络,以对层进行命名
import torch
import torch.nn as nn
from collections import OrderedDict
class OrderedDictCNN(nn.Module):
    """Small CNN classifier whose layers are named through an ``OrderedDict``.

    The whole network is a single ``nn.Sequential`` built from an
    ``OrderedDict``, so every layer can be reached by name
    (for example ``self.model.conv1``).

    The first fully connected layer is sized for ``(N, 3, 224, 224)`` inputs.
    The stride-2 stem convolution and the two pooling layers each halve the
    spatial size (224 -> 112 -> 56 -> 28), so the tensor entering ``flatten``
    is ``128 x 28 x 28``. Other input resolutions need a different
    ``fc1.in_features``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(OrderedDict([
            # Stem: 224 -> 112 (stride-2 conv), then 112 -> 56 (stride-2 pool).
            ('conv1', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)),
            ('bn1', nn.BatchNorm2d(64)),
            ('relu1', nn.ReLU(inplace=True)),
            ('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            # Feature extraction: the conv keeps 56 x 56, the pool gives 28 x 28.
            ('conv2', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)),
            ('bn2', nn.BatchNorm2d(128)),
            ('relu2', nn.ReLU(inplace=True)),
            ('maxpool2', nn.MaxPool2d(kernel_size=2, stride=2, padding=0)),
            # Flatten (N, 128, 28, 28) into (N, 128 * 28 * 28).
            ('flatten', nn.Flatten()),
            # Fully connected head; in_features must equal the flattened size.
            ('fc1', nn.Linear(128 * 28 * 28, 1000)),
            ('relu3', nn.ReLU(inplace=True)),
            # Output layer with 10 units.
            ('fc2', nn.Linear(1000, 10)),
        ]))

    def forward(self, x):
        """Pass ``x`` through the named layers in order.

        Args:
            x: Input batch; expected shape ``(N, 3, 224, 224)`` (see class
                docstring for why the spatial size is fixed).

        Returns:
            The output of the final linear layer, shape ``(N, 10)``.
        """
        return self.model(x)
使用多个nn.Sequential组织网络
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small CNN classifier split into three ``nn.Sequential`` stages.

    The stages are ``stem`` (initial convolution), ``feature_extraction``
    and ``fc`` (classifier head); ``forward`` chains them explicitly.

    The first linear layer is sized for ``(N, 3, 224, 224)`` inputs. This
    variant downsamples only twice: the stride-2 stem convolution
    (224 -> 112) and the pooling layer in ``feature_extraction``
    (112 -> 56), so the tensor entering ``nn.Flatten`` is ``128 x 56 x 56``.
    Other input resolutions need a different ``in_features``.
    """

    def __init__(self):
        super().__init__()
        # Stem: the stride-2 convolution halves the spatial size (224 -> 112).
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Feature extraction: the conv keeps 112 x 112, the pool gives 56 x 56.
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier head; in_features must equal the flattened feature size.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 56 * 56, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, 10),
        )

    def forward(self, x):
        """Run the three stages in sequence.

        Args:
            x: Input batch; expected shape ``(N, 3, 224, 224)``.

        Returns:
            The output of the final linear layer, shape ``(N, 10)``.
        """
        x = self.stem(x)
        x = self.feature_extraction(x)
        x = self.fc(x)
        return x
使用单个nn.Sequential组织网络
import torch
import torch.nn as nn
class SequentialCNN(nn.Module):
    """Small CNN classifier held in one anonymous ``nn.Sequential``.

    Layers are addressed by position (``self.model[0]``, ``self.model[1]``,
    ...) because no names are supplied.

    The first linear layer is sized for ``(N, 3, 224, 224)`` inputs. The
    stride-2 stem convolution and the two pooling layers each halve the
    spatial size (224 -> 112 -> 56 -> 28), so the tensor entering
    ``nn.Flatten`` is ``128 x 28 x 28``. Other input resolutions need a
    different ``in_features``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # Stem: 224 -> 112 (stride-2 conv), then 112 -> 56 (stride-2 pool).
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            # Feature extraction: the conv keeps 56 x 56, the pool gives 28 x 28.
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # Flatten (N, 128, 28, 28) into (N, 128 * 28 * 28).
            nn.Flatten(),
            # Fully connected head; in_features must equal the flattened size.
            nn.Linear(128 * 28 * 28, 1000),
            nn.ReLU(inplace=True),
            # Output layer with 10 units.
            nn.Linear(1000, 10),
        )

    def forward(self, x):
        """Pass ``x`` through the sequential stack.

        Args:
            x: Input batch; expected shape ``(N, 3, 224, 224)``.

        Returns:
            The output of the final linear layer, shape ``(N, 10)``.
        """
        return self.model(x)
使用nn.ModuleList组织
ModuleList静态组织示例
import torch
import torch.nn as nn
class ModuleListCNN(nn.Module):
    """Small CNN classifier whose layers are stored in an ``nn.ModuleList``.

    ``nn.ModuleList`` only registers the layers; it does not call them, so
    ``forward`` iterates over the list and applies each layer itself.

    The first linear layer is sized for ``(N, 3, 224, 224)`` inputs. The
    stride-2 stem convolution and the two pooling layers each halve the
    spatial size (224 -> 112 -> 56 -> 28), so the tensor entering
    ``nn.Flatten`` is ``128 x 28 x 28``. Other input resolutions need a
    different ``in_features``.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            # Stem: 224 -> 112 (stride-2 conv), then 112 -> 56 (stride-2 pool).
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            # Feature extraction: the conv keeps 56 x 56, the pool gives 28 x 28.
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # Flatten (N, 128, 28, 28) into (N, 128 * 28 * 28).
            nn.Flatten(),
            # Fully connected head; in_features must equal the flattened size.
            nn.Linear(128 * 28 * 28, 1000),
            nn.ReLU(inplace=True),
            # Output layer with 10 units.
            nn.Linear(1000, 10),
        ])

    def forward(self, x):
        """Apply every stored layer in list order.

        Args:
            x: Input batch; expected shape ``(N, 3, 224, 224)``.

        Returns:
            The output of the final linear layer, shape ``(N, 10)``.
        """
        # Each layer's output becomes the next layer's input.
        for layer in self.layers:
            x = layer(x)
        return x
# Instantiate the model and print its layer structure
model = ModuleListCNN()
print(model)
ModuleList动态组织示例
class Transformer(nn.Module):
    """Stack of ``depth`` (Attention, FeedForward) pairs with a final LayerNorm.

    Each pair is stored as an inner ``nn.ModuleList`` inside the outer
    ``self.layers`` list. ``forward`` unpacks every pair and adds the input
    back to the output of each sub-module before moving on.

    Args:
        dim: Passed to ``nn.LayerNorm``, ``Attention`` and ``FeedForward``.
        depth: Number of (Attention, FeedForward) pairs to create.
        heads: Forwarded to ``Attention`` as ``heads``.
        dim_head: Forwarded to ``Attention`` as ``dim_head``.
        mlp_dim: Second positional argument of ``FeedForward``.
        dropout: Forwarded to both ``Attention`` and ``FeedForward``.

    NOTE(review): ``Attention`` and ``FeedForward`` are defined outside this
    snippet; the additions in ``forward`` presumably require them to return
    a tensor shaped like their input. Confirm against their definitions.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # One inner ModuleList per depth level, created in the same
        # Attention-then-FeedForward order for every level.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Apply every pair with additive skip connections, then normalize."""
        for block in self.layers:
            attention, feed_forward = block
            x = attention(x) + x
            x = feed_forward(x) + x
        return self.norm(x)