pytorch搭建Lenet网络模型(以CIFAR10为例)

前言

本文介绍利用Pytorch搭建Lenet模型,以cifar10为例,具体流程见博客

https://blog.csdn.net/qq_43542339/article/details/106058752

Lenet

  • 代码
    由于cifar10为三通道图像,故self.conv1in_channels参数为3
class Lenet_torch(nn.Module):
    def __init__(self):
        super(Lenet_torch,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5,stride=1)  #output:28*28*6
        self.maxpool1 = nn.MaxPool2d(2)  #output:14*14*6
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,stride=1)  #output:10*10*16
        self.maxpool2 = nn.MaxPool2d(2)  #output:5*5*16
        self.fc1 = nn.Linear(in_features=16*5*5,out_features=120)
        self.fc2 = nn.Linear(in_features=120,out_features=84)
        self.fc3 = nn.Linear(in_features=84,out_features=10)  #out_features为类别数

    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.maxpool1(x)

        x = x.view(x.size(0),-1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        return x

  • print(net)查看模型
    在这里插入图片描述
  • 运行成功,说明网络模型搭建成功
    在这里插入图片描述

小结

  • Lenet网络较为简单,搭建起来很容易,注意一下这里不同于VGG,没有利用padding进行填充(当然自己使用的时候完全可以更改)
  • 如有不足还请多多指教

猜你喜欢

转载自blog.csdn.net/qq_43542339/article/details/106134528