Convolutional Neural Networks: Defining a Residual Net (ResNet)

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual block: stacks two conv layers and returns relu(x + y),
    where y has the same shape as the input x."""
    def __init__(self, channels):
        # The channel count is left as a parameter so it can be chosen
        # when the block is instantiated.
        super(ResidualBlock, self).__init__()
        self.channels = channels
        # Both convolutions keep the channel count and (with padding=1) the
        # spatial size unchanged, so x and y can be added element-wise.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        # Add the input x before the second activation: sum first, then ReLU.
        return F.relu(x + y)
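
Because both convolutions use 3×3 kernels with padding=1 and preserve the channel count, a ResidualBlock's output has exactly the same shape as its input, which is what makes the element-wise sum x + y valid. A minimal sanity check (the tensor sizes below are only illustrative):

block = ResidualBlock(16)
dummy = torch.randn(4, 16, 12, 12)   # batch of 4 feature maps with 16 channels
out = block(dummy)
print(out.shape)                     # torch.Size([4, 16, 12, 12]) -- same shape as the input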

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.mp = nn.MaxPool2d(2)

        self.rblock1 = ResidualBlock(16)
        self.rblock2 = ResidualBlock(32)

        # 512 = 32 channels * 4 * 4 spatial size for a 1x28x28 input;
        # the output size should match the number of classes.
        self.fc = nn.Linear(512, 20)

    def forward(self, x):
        in_size = x.size(0)
        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)
        x = x.view(in_size, -1)   # flatten to (batch_size, 512)
        x = self.fc(x)
        return x
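
As a minimal usage sketch, assuming 1×28×28 inputs such as MNIST: conv1 (5×5, no padding) gives 16×24×24, pooling gives 16×12×12, conv2 gives 32×8×8, and the second pooling gives 32×4×4, i.e. 32 * 4 * 4 = 512 features per sample, matching the in_features of self.fc. The residual blocks do not change any shapes.

model = Net()
x = torch.randn(8, 1, 28, 28)   # dummy batch of 8 single-channel 28x28 images
logits = model(x)
print(logits.shape)             # torch.Size([8, 20]) -- one score per output unit per sample

The final dimension (20 here) simply follows the Linear layer defined above; for a 10-class task like MNIST it would typically be set to the number of classes.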

Reprinted from blog.csdn.net/qq_21686871/article/details/114380547