Tensorflow2.0之自定义ResNet

1、ResNet 网络结构

残差网络由残差块(ResnetBlock)组成,每一个残差块又是由多个 Residual 构成的。下面以 ResNet18 为例,分析残差网络结构的构建。

1.1 Residual

Residual 的结构为:
在这里插入图片描述上图中卷积层参数的设置:

  • 卷积层1:kernel_size=3x3,padding=‘same’
  • 卷积层2:kernel_size=3x3,padding=‘same’,strides=1
  • 卷积层3:kernel_size=1x1,padding=‘valid’
  • 注1:卷积层的通道数都是相同的,由用户指定。
  • 注2:卷积层1和卷积层3的步长相等,由用户指定。

XY 的过程:

  • 先通过卷积层1,此时的通道数可能发生改变,由于步长不确定,故此时的长宽也会发生改变。
  • 然后经过卷积层2,由于卷积层2的步长确定为1,通道数和卷积层1的相同,所以这一层的输入和输出的 shape 是相同的。

但是,在进行 X + Y 的时候,我们需要让 XYshape 相同,所以此时需要卷积层3的参与。
如果 YshapeX 的不同,则将 X 输入到卷积层3来使其 shape 等于 Yshape ;否则,卷积层3不被需要。

1.2 ResnetBlock

因为在残差网络中,第一个残差块的输入和输出的 shape 是相同的,而在其他残差块中,输出的长和宽是输入的一半,而且输出的通道数是输入的两倍。所以残差网络中的第一个残差块中的所有 Residual 都不需要卷积层3,而其他残差块中的第一个 Residual 都需要卷积层3使 XYshape 相同。所以 ResnetBlock 的构建规则为:
在这里插入图片描述

1.3 ResNet

残差网络的参数变化如图所示:
在这里插入图片描述

  • conv1:卷积层、BN层、ReLU层;
  • conv2_x:池化层、残差块;
  • conv3_x:残差块;
  • conv4_x:残差块;
  • conv5_x:残差块;
  • 全局平均池化层后接上全连接层输出。

2、代码构建残差网络

2.1 Residual

class Residual(tf.keras.Model):
    def __init__(self, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(num_channels, kernel_size=3,
                                            strides=strides, padding='same')
        self.conv2 = tf.keras.layers.Conv2D(num_channels, kernel_size=3,
                                            padding='same')
        if use_1x1conv:
            self.conv3 = tf.keras.layers.Conv2D(num_channels,
                                       kernel_size=1,
                                       strides=strides)
        else:
            self.conv3 = None
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()
        
    def call(self, x):
        y = tf.nn.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.conv3:
            x = self.conv3(x)
        return tf.nn.relu(y + x)

2.2 ResnetBlock

class ResnetBlock(tf.keras.Model):
    def __init__(self, num_channels, num_residuals, first_block=False):
        super().__init__()
        self.listLayers=[]
        for i in range(num_residuals):
            if i == 0 and not first_block:
                self.listLayers.append(Residual(num_channels,
                                                use_1x1conv=True,
                                                strides=2))
            else:
                self.listLayers.append(Residual(num_channels))
                
    def call(self, x):
        for layer in self.listLayers.layers:
            x = layer(x)
        return x

2.3 ResNet

class ResNet(tf.keras.Model):
    def __init__(self, num_blocks):
        super().__init__()
        self.conv=layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn=layers.BatchNormalization()
        self.relu=layers.Activation('relu')
        
        self.mp=layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        self.resnet_block1=ResnetBlock(64,num_blocks[0], first_block=True)
        
        self.resnet_block2=ResnetBlock(128,num_blocks[1])
        
        self.resnet_block3=ResnetBlock(256,num_blocks[2])
        
        self.resnet_block4=ResnetBlock(512,num_blocks[3])
        
        self.gap=layers.GlobalAvgPool2D()
        self.fc=layers.Dense(units=10,activation=tf.keras.activations.softmax)

    def call(self, x):
        x=self.conv(x)
        x=self.bn(x)
        x=self.relu(x)
        x=self.mp(x)
        x=self.resnet_block1(x)
        x=self.resnet_block2(x)
        x=self.resnet_block3(x)
        x=self.resnet_block4(x)
        x=self.gap(x)
        x=self.fc(x)
        return x

2.4 网络检验

mynet=ResNet([2,2,2,2])
X = tf.random.uniform(shape=(1,  224, 224 , 1))
for layer in mynet.layers:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)
conv2d_51 output shape:	 (1, 112, 112, 64)
batch_normalization_42 output shape:	 (1, 112, 112, 64)
activation_2 output shape:	 (1, 112, 112, 64)
max_pooling2d_2 output shape:	 (1, 56, 56, 64)
resnet_block_9 output shape:	 (1, 56, 56, 64)
resnet_block_10 output shape:	 (1, 28, 28, 128)
resnet_block_11 output shape:	 (1, 14, 14, 256)
resnet_block_12 output shape:	 (1, 7, 7, 512)
global_average_pooling2d_2 output shape:	 (1, 512)
dense_2 output shape:	 (1, 10)
发布了120 篇原创文章 · 获赞 15 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/qq_36758914/article/details/104939574