PyTorch 中常见的 4 种模型组织方式:nn.Sequential(OrderedDict)、多个 nn.Sequential、单个 nn.Sequential 与 nn.ModuleList

解释

  1. nn.ModuleList:

    • nn.ModuleList 是一个容器,用于存储多个 nn.Module。与 nn.Sequential 不同,nn.ModuleList 不会自动执行层的连接,因此需要在 forward 方法中手动实现前向传播。
  2. 特性:

    • nn.ModuleList 适用于需要灵活控制网络层连接的场景,例如当网络层不是简单的线性堆叠时(如存在条件分支或跳跃连接)。
    • 它与 nn.Sequential 的主要区别在于:nn.ModuleList 只是按顺序保存并注册子模块,本身没有定义 forward,不会隐式地把各层串联起来,前向传播的实现需要手动编写。
  3. 前向传播:

    • forward 方法中,依次通过 self.layers 中的每个层对输入数据 x 进行处理。每个层的输出会作为下一个层的输入。

总结

  • nn.Sequential(OrderedDict): 在 nn.Sequential 中传入 OrderedDict,适用于需要对层进行命名,并明确顺序的场景。
  • nn.Sequential: 适用于网络层简单的线性堆叠。
  • nn.ModuleList: 提供更大的灵活性,适用于复杂的网络结构。

每种方式都有其适用场景,选择适当的方法可以使模型设计更加直观和灵活。

示例

在nn.Sequential中嵌套OrderedDict组织网络,以对层进行命名

import torch
import torch.nn as nn
from collections import OrderedDict

class OrderedDictCNN(nn.Module):
    """Small CNN classifier built as one ``nn.Sequential`` over an ``OrderedDict``.

    Passing an ``OrderedDict`` gives every layer a name, so layers are
    reachable as attributes (for example ``self.model.conv1``) and the
    ``state_dict`` keys are readable.

    The input width of ``fc1`` is derived from ``input_size`` instead of being
    hard-coded. The previous constant ``128 * 112 * 112`` only matched
    896x896 images and allocated roughly 1.6 billion weights; for the default
    224x224 input the flattened feature map is ``128 * 28 * 28``.

    Args:
        input_size: Spatial size of the input images, either an int for
            square images or a ``(height, width)`` pair.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``input_size`` is too small to survive the stride-2
            layers.
    """

    def __init__(self, input_size=224, num_classes=10):
        super().__init__()
        flat_features = self._flattened_features(input_size)
        # Layers are declared through an OrderedDict so that each one is named.
        self.model = nn.Sequential(OrderedDict([
            # Stem convolution
            ('conv1', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)),
            ('bn1', nn.BatchNorm2d(64)),
            ('relu1', nn.ReLU(inplace=True)),
            ('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),

            # Feature extraction
            ('conv2', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)),
            ('bn2', nn.BatchNorm2d(128)),
            ('relu2', nn.ReLU(inplace=True)),
            ('maxpool2', nn.MaxPool2d(kernel_size=2, stride=2, padding=0)),

            # Classifier head
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(flat_features, 1000)),
            ('relu3', nn.ReLU(inplace=True)),
            ('fc2', nn.Linear(1000, num_classes)),
        ]))

    @staticmethod
    def _flattened_features(input_size):
        """Return the length of the vector that ``flatten`` produces per sample."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # (kernel_size, stride, padding) of conv1, maxpool1, conv2, maxpool2.
        # Keep in sync with the layers built in ``__init__``.
        for kernel, stride, pad in ((7, 2, 3), (3, 2, 1), (3, 1, 1), (2, 2, 0)):
            height = (height + 2 * pad - kernel) // stride + 1
            width = (width + 2 * pad - kernel) // stride + 1
        if height < 1 or width < 1:
            raise ValueError(f"input_size {input_size!r} is too small for this network")
        return 128 * height * width

    def forward(self, x):
        """Run ``x`` through the named sequential stack and return the logits."""
        return self.model(x)

使用多个nn.Sequential组织网络

import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small CNN classifier split into three ``nn.Sequential`` stages.

    The stages are ``stem`` (strided convolution), ``feature_extraction``
    (convolution plus 2x2 max-pooling) and ``fc`` (flatten plus two linear
    layers). The stem has no max-pooling, so the spatial size is reduced 4x
    overall.

    The input width of the first linear layer is derived from ``input_size``
    instead of being hard-coded. The previous constant ``128 * 112 * 112``
    only matched 448x448 images and allocated roughly 1.6 billion weights; for
    the default 224x224 input the flattened feature map is ``128 * 56 * 56``.

    Args:
        input_size: Spatial size of the input images, either an int for
            square images or a ``(height, width)`` pair.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``input_size`` is too small to survive the stride-2
            layers.
    """

    def __init__(self, input_size=224, num_classes=10):
        super().__init__()
        flat_features = self._flattened_features(input_size)
        # Stem convolution
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Feature extraction
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier head
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        )

    @staticmethod
    def _flattened_features(input_size):
        """Return the length of the vector that ``nn.Flatten`` produces per sample."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # (kernel_size, stride, padding) of the stem conv, the feature conv and
        # the max-pool. Keep in sync with the layers built in ``__init__``.
        for kernel, stride, pad in ((7, 2, 3), (3, 1, 1), (2, 2, 0)):
            height = (height + 2 * pad - kernel) // stride + 1
            width = (width + 2 * pad - kernel) // stride + 1
        if height < 1 or width < 1:
            raise ValueError(f"input_size {input_size!r} is too small for this network")
        return 128 * height * width

    def forward(self, x):
        """Apply the three stages in order and return the logits."""
        x = self.stem(x)
        x = self.feature_extraction(x)
        x = self.fc(x)
        return x

使用单个nn.Sequential组织网络

import torch
import torch.nn as nn

class SequentialCNN(nn.Module):
    """Small CNN classifier expressed as a single positional ``nn.Sequential``.

    Layers are unnamed and addressed by index (``self.model[0]`` and so on).

    The input width of the first linear layer is derived from ``input_size``
    instead of being hard-coded. The previous constant ``128 * 112 * 112``
    only matched 896x896 images and allocated roughly 1.6 billion weights; for
    the default 224x224 input the flattened feature map is ``128 * 28 * 28``.

    Args:
        input_size: Spatial size of the input images, either an int for
            square images or a ``(height, width)`` pair.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``input_size`` is too small to survive the stride-2
            layers.
    """

    def __init__(self, input_size=224, num_classes=10):
        super().__init__()
        flat_features = self._flattened_features(input_size)
        self.model = nn.Sequential(
            # Stem convolution
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Feature extraction
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            # Classifier head
            nn.Flatten(),
            nn.Linear(flat_features, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        )

    @staticmethod
    def _flattened_features(input_size):
        """Return the length of the vector that ``nn.Flatten`` produces per sample."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # (kernel_size, stride, padding) of the two convolutions and two
        # max-pools, in order. Keep in sync with the layers in ``__init__``.
        for kernel, stride, pad in ((7, 2, 3), (3, 2, 1), (3, 1, 1), (2, 2, 0)):
            height = (height + 2 * pad - kernel) // stride + 1
            width = (width + 2 * pad - kernel) // stride + 1
        if height < 1 or width < 1:
            raise ValueError(f"input_size {input_size!r} is too small for this network")
        return 128 * height * width

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the logits."""
        return self.model(x)

使用nn.ModuleList组织

ModuleList静态组织示例

import torch
import torch.nn as nn

class ModuleListCNN(nn.Module):
    """Small CNN classifier whose layers are held in an ``nn.ModuleList``.

    ``nn.ModuleList`` only stores and registers the layers; ``forward`` chains
    them explicitly.

    The input width of the first linear layer is derived from ``input_size``
    instead of being hard-coded. The previous constant ``128 * 112 * 112``
    only matched 896x896 images and allocated roughly 1.6 billion weights; for
    the default 224x224 input the flattened feature map is ``128 * 28 * 28``.

    Args:
        input_size: Spatial size of the input images, either an int for
            square images or a ``(height, width)`` pair.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``input_size`` is too small to survive the stride-2
            layers.
    """

    def __init__(self, input_size=224, num_classes=10):
        super().__init__()
        flat_features = self._flattened_features(input_size)
        self.layers = nn.ModuleList([
            # Stem convolution
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Feature extraction
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            # Classifier head
            nn.Flatten(),
            nn.Linear(flat_features, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        ])

    @staticmethod
    def _flattened_features(input_size):
        """Return the length of the vector that ``nn.Flatten`` produces per sample."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # (kernel_size, stride, padding) of the two convolutions and two
        # max-pools, in order. Keep in sync with the layers in ``__init__``.
        for kernel, stride, pad in ((7, 2, 3), (3, 2, 1), (3, 1, 1), (2, 2, 0)):
            height = (height + 2 * pad - kernel) // stride + 1
            width = (width + 2 * pad - kernel) // stride + 1
        if height < 1 or width < 1:
            raise ValueError(f"input_size {input_size!r} is too small for this network")
        return 128 * height * width

    def forward(self, x):
        """Feed ``x`` through every stored layer in order and return the logits."""
        for layer in self.layers:
            x = layer(x)
        return x

# Instantiate the model and print its module structure
model = ModuleListCNN()
print(model)

ModuleList动态组织示例

class Transformer(nn.Module):
    """Stack of ``depth`` attention / feed-forward pairs with residual additions.

    ``self.layers`` is an ``nn.ModuleList`` whose entries are themselves
    two-element ``nn.ModuleList`` objects ``[Attention, FeedForward]``, built
    dynamically from ``depth``. A final ``nn.LayerNorm(dim)`` is applied to the
    result.

    Args:
        dim: Size passed to ``nn.LayerNorm``, ``Attention`` and ``FeedForward``.
        depth: Number of attention / feed-forward pairs to create.
        heads: Forwarded to ``Attention`` as ``heads``.
        dim_head: Forwarded to ``Attention`` as ``dim_head``.
        mlp_dim: Forwarded to ``FeedForward`` as its second positional argument.
        dropout: Forwarded to both sub-modules as ``dropout``.

    NOTE(review): ``Attention`` and ``FeedForward`` are not defined in this
    snippet; they are presumably modules that return a tensor of the same
    shape as their input, which the residual additions require. Confirm
    against their definitions.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # One inner ModuleList per depth step: [attention, feed-forward].
        blocks = [
            nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every pair with residual additions, then the final layer norm."""
        for block in self.layers:
            attention, feed_forward = block
            x = attention(x) + x
            x = feed_forward(x) + x
        return self.norm(x)

猜你喜欢

转载自blog.csdn.net/qq_37293230/article/details/140630874