Torch Paper Reproduction: Structural Re-parameterization with RepVGGBlock

ShuffleNet v2 proposed four design guidelines for lightweight networks:

  • MAC (memory access cost) is minimized when the input and output channel counts are equal
  • At equal FLOPs, grouped convolution with too many groups increases MAC
  • Fragmented operations (multi-branch structures) are unfriendly to parallel acceleration
  • The memory traffic and latency of element-wise operations are not negligible

In recent years, convolutional network architectures have grown increasingly complex; thanks to the good convergence behavior of multi-branch structures, they have become more and more popular.

However, a multi-branch structure cannot fully exploit parallel acceleration, and it also increases MAC.


To let a plain structure reach accuracy comparable to multi-branch ones, RepVGG trains with a multi-branch structure (3×3 conv + 1×1 conv + identity mapping) to benefit from its good convergence; at inference/deployment time, re-parameterization converts the multi-branch structure into a single-path one to benefit from the extreme speed of plain structures.


Re-parameterization

In the multi-branch structure used for training, every branch contains a BN layer.

A BN layer has four parameters used at inference time: mean, var, weight, and bias. It applies the following transform to the input x:

$$BN(x) = weight \cdot \frac{x - mean}{\sqrt{var}} + bias$$

Rewriting it in the form $BN(x) = w_{bn} \cdot x + b_{bn}$:

$$w_{bn} = \frac{weight}{\sqrt{var}},\quad b_{bn} = bias - \frac{weight \cdot mean}{\sqrt{var}}$$

import torch
from torch import nn


class BatchNorm(nn.BatchNorm2d):

    def unpack(self, detach=False):
        mean, bias = self.running_mean, self.bias
        std = (self.running_var + self.eps).float().sqrt().type_as(mean)
        weight = self.weight / std
        eq_param = weight, bias - weight * mean
        return tuple(map(lambda x: x.data, eq_param)) if detach else eq_param


bn = BatchNorm(8).eval()
# initialize with random parameters
bn.running_mean.data, bn.running_var.data, bn.weight.data, bn.bias.data = torch.rand([4, 8])

image = torch.rand([1, 8, 1, 1])
print(bn(image).view(-1))
# convert the BN parameters into w, b form
weight, bias = bn.unpack()
print(image.view(-1) * weight + bias)

Because the BN layer fits a per-channel bias, when a conv layer is followed by a BN layer the conv layer uses no bias, and the computation can be written as:

$$Conv(x) = w_c * x$$

$$BN(Conv(x)) = w_{bn} w_c * x + b_{bn}$$

Hence, a conv layer followed by a BN layer is equivalent to a single conv layer with bias.
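
As a sanity check, this fusion can be verified in a few lines (a minimal sketch independent of the classes below; the layer sizes are arbitrary):

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()
# random BN statistics (variance shifted away from zero)
bn.running_mean.data = torch.rand(8)
bn.running_var.data = torch.rand(8) + 0.5
bn.weight.data = torch.rand(8)
bn.bias.data = torch.rand(8)

# fold BN into the conv: w_bn = weight / sqrt(var + eps), b_bn = bias - w_bn * mean
w_bn = bn.weight.data / (bn.running_var + bn.eps).sqrt()
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
fused.weight.data = conv.weight.data * w_bn[:, None, None, None]
fused.bias.data = bn.bias.data - w_bn * bn.running_mean

x = torch.rand(1, 3, 5, 5)
# Conv -> BN and the fused biased conv give the same result
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-6)
```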


The identity mapping is likewise equivalent to a 1×1 convolution:

  • For nn.Conv2d(c1, c2, kernel_size=1), the weight has shape [c2, c1, 1, 1], so it can be viewed as a [c2, c1] linear layer that applies a channel transform at every pixel (see: how Torch computes 2-D multi-channel convolution)
  • When c1 = c2 and this linear layer is the identity matrix, the convolution is equivalent to the identity mapping
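
This equivalence is easy to confirm (a minimal sketch; the channel count is arbitrary):

```python
import torch
from torch import nn

c = 4  # arbitrary channel count
conv = nn.Conv2d(c, c, kernel_size=1, bias=False)
# make the [c, c] "linear layer" an identity matrix
conv.weight.data = torch.eye(c).view(c, c, 1, 1)

x = torch.rand(1, c, 5, 5)
# the 1x1 convolution now reproduces its input
assert torch.allclose(conv(x), x)
```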

A 1×1 convolution can in turn be expressed as a 3×3 convolution via zero padding, so the computation of the multi-branch structure can be written as:

$$BN_{3 \times 3}(Conv_{3 \times 3}(x)) = w_3 * x + b_3$$

$$BN_{1 \times 1}(Conv_{1 \times 1}(x)) = w_1 * x + b_1$$

$$BN_{id}(Conv_{id}(x)) = w_0 * x + b_0$$

$$y = (w_3 + w_1 + w_0) * x + (b_3 + b_1 + b_0)$$

so the whole block is equivalent to a single new 3×3 convolution (the conclusion also extends to grouped convolution and to 5×5 kernels).
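
The padding trick and the branch merge can be sketched as follows (shapes chosen only for illustration; the full version is the merge method below):

```python
import torch
import torch.nn.functional as F
from torch import nn

c1, c2 = 3, 8
conv3 = nn.Conv2d(c1, c2, 3, padding=1, bias=False)
conv1 = nn.Conv2d(c1, c2, 1, bias=False)

# zero-pad the 1x1 kernel on each side so its value sits at the 3x3 center
k1_as_3x3 = F.pad(conv1.weight.data, [1, 1, 1, 1])

merged = nn.Conv2d(c1, c2, 3, padding=1, bias=False)
merged.weight.data = conv3.weight.data + k1_as_3x3

x = torch.rand(1, c1, 6, 6)
# the sum of the two branches equals the single merged 3x3 conv
assert torch.allclose(conv3(x) + conv1(x), merged(x), atol=1e-6)
```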

A speed test on an NVIDIA 1080Ti, feeding a [32, 2048, 56, 56] input through kernels that produce an output of the same channels and spatial size, shows that the 3×3 convolution delivers the most floating-point operations per second.


Reproducing the Structure

Reference code: https://github.com/DingXiaoH/RepVGG

I refactored the paper's source code to improve readability and usability (re-parameterization and the L2-norm computation are written as class methods, so they can conveniently operate on an assembled model).

I also applied the re-parameterization technique to the simple CBS module (Conv - BN - SiLU), wrapped as the Conv class. Conv's class method re_param merges every conv layer and BN layer in an assembled model.

On top of that, I wrote the class method merge for RepConv to fuse the multi-branch structure. It has been verified to work in all cases, including grouped convolution.

from collections import OrderedDict
from typing import Optional

import torch
from torch import nn


class BatchNorm(nn.BatchNorm2d):

    def unpack(self, detach=False):
        mean, bias = self.running_mean, self.bias
        std = (self.running_var + self.eps).float().sqrt().type_as(mean)
        weight = self.weight / std
        eq_param = weight, bias - weight * mean
        return tuple(map(lambda x: x.data, eq_param)) if detach else eq_param


class Conv(nn.Module):
    ''' Conv - BN - Act'''
    deploy = property(fget=lambda self: isinstance(self.conv, nn.Conv2d))

    def __init__(self, c1, c2, k=3, s=1, g=1,
                 act: Optional[nn.Module] = nn.SiLU, deploy=False):
        super(Conv, self).__init__()
        assert k & 1, 'The convolution kernel size must be odd'
        self._config = dict(
            in_channels=c1, out_channels=c2, kernel_size=k,
            stride=s, padding=k // 2, groups=g
        )
        self.conv = nn.Sequential(OrderedDict(
            conv=nn.Conv2d(**self._config, bias=False),
            bn=BatchNorm(c2)
        )) if not deploy else nn.Conv2d(**self._config, bias=True)
        self.act = act() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

    @classmethod
    def re_param(cls, model: nn.Module):
        for m in filter(lambda m: isinstance(m, cls) and not m.deploy, model.modules()):
            kernel = m.conv.conv.weight.data
            bn_w, bn_b = m.conv.bn.unpack(detach=True)
            # merge nn.Conv2d with BatchNorm
            m.conv = nn.Conv2d(**m._config, bias=True)
            m.conv.weight.data, m.conv.bias.data = kernel * bn_w[:, None, None, None], bn_b


class RepConv(nn.Module):
    ''' RepVGGBlock
        identity: use the identity-mapping branch
        conv_1x1: use the 1×1 convolution branch
        deploy: use the deployment structure'''
    deploy = property(fget=lambda self: isinstance(self.conv_main, nn.Conv2d))

    def __init__(self, c1, c2, k=3, s=1, g=1,
                 act: Optional[nn.Module] = nn.SiLU,
                 identity=True, conv_1x1=True, deploy=False):
        super(RepConv, self).__init__()
        assert k & 1, 'The convolution kernel size must be odd'
        self._center = k // 2
        if deploy:
            self.conv_main = nn.Conv2d(in_channels=c1, out_channels=c2, kernel_size=k,
                                       stride=s, padding=k // 2, groups=g, bias=True)
            self.conv_1x1, self.identity = [None] * 2
        else:
            self.conv_main = Conv(c1, c2, k, s, g, act=None, deploy=False)
            self.conv_1x1 = Conv(c1, c2, 1, s, g, act=None, deploy=False) if conv_1x1 and k != 1 else None
            self.identity = BatchNorm(c2) if identity and c1 == c2 and s == 1 else None
            assert self.conv_1x1 or self.identity, f'This module can be replaced by {Conv}'
        self.act = act() if act else nn.Identity()

    def forward(self, x):
        y = self.conv_main(x)
        for attr in filter(lambda x: getattr(self, x),
                           ['conv_1x1', 'identity']):
            y += getattr(self, attr)(x)
        return self.act(y)

    @classmethod
    def l2_loss(cls, model):
        get_bn_w = lambda x: x.bn.unpack(detach=True)[0][:, None]
        l2_loss = 0
        for m in filter(lambda m: isinstance(m, cls) and not m.deploy, model.modules()):
            conv_main = m.conv_main.conv
            kernel = conv_main.conv.weight
            l2_loss += (kernel ** 2).sum()
            # stack the 1×1 conv branch and compute the L2 of the center point
            if m.conv_1x1:
                conv_1x1 = m.conv_1x1.conv
                # preprocess the branch parameters
                kernel, bn_main_w = kernel[..., m._center, m._center], get_bn_w(conv_main)
                kernel_1x1, bn_1x1_w = conv_1x1.conv.weight[..., 0, 0], get_bn_w(conv_1x1)
                # aggregate the weights and compute the L2 term
                l2_center = kernel * bn_main_w + kernel_1x1 * bn_1x1_w
                l2_center = l2_center ** 2 / (bn_main_w ** 2 + bn_1x1_w ** 2)
                # subtract the center term so its weights are not counted twice
                l2_loss += l2_center.sum() - (kernel ** 2).sum()
        return l2_loss

    @classmethod
    def merge(cls, model: nn.Module):
        Conv.re_param(model)
        # walk all submodules of the model and merge each RepConv
        for m in filter(lambda m: isinstance(m, cls) and not m.deploy, model.modules()):
            conv_main = m.conv_main.conv
            kernel_w, kernel_b = conv_main.weight.data, conv_main.bias.data
            # fold in the 1×1 conv branch: nn.Conv2d
            if m.conv_1x1:
                conv_1x1 = m.conv_1x1.conv
                kernel_w[..., m._center, m._center] += conv_1x1.weight[..., 0, 0].data
                kernel_b += conv_1x1.bias.data
            # fold in the identity-mapping branch: BatchNorm
            if m.identity:
                g, device = conv_main.groups, kernel_w.device
                kernel_id = torch.eye(kernel_w.shape[1]).repeat(g, 1).to(device)
                bn_w, bn_b = m.identity.unpack(detach=True)
                kernel_w[..., m._center, m._center] += kernel_id * bn_w.view(-1, 1)
                kernel_b += bn_b
            # create the merged convolution kernel
            m.conv_main = nn.Conv2d(**m.conv_main._config, bias=True)
            m.conv_main.weight.data, m.conv_main.bias.data = kernel_w, kernel_b
            # remove the branches that were merged
            for attr in ['conv_1x1', 'identity']: setattr(m, attr, None)

Then build a test model to verify:

  • whether the merge method changes the network structure
  • whether the model's outputs are identical before and after re-parameterization
  • whether inference is faster after re-parameterization

def timer(repeat=1, avg=True):
    import time
    repeat = int(repeat) if isinstance(repeat, float) else repeat

    def decorator(func):
        def handler(*args, **kwargs):
            start = time.time()
            for _ in range(max([repeat, 1])):
                result = func(*args, **kwargs)
            cost = time.time() - start
            if avg: cost /= repeat
            print(f'{cost * 1e3:.3f} ms')
            return result

        return handler

    return decorator


class Random_Model(nn.Module):

    def __init__(self, c1=3, c_=8, deploy=False):
        super(Random_Model, self).__init__()
        self.model = nn.Sequential(
            RepConv(c1, c_, deploy=deploy),
            RepConv(c_, c_, k=1, deploy=deploy),
            RepConv(c_, c_, g=2, deploy=deploy)
        )

    @timer(10)
    def forward(self, x):
        return self.model(x)[0].sum(dim=0)


if __name__ == '__main__':
    model = Random_Model(deploy=False).eval()
    print(model, '\n')

    # initialize each BatchNorm with random parameters
    for m in filter(lambda m: isinstance(m, BatchNorm), model.modules()):
        m.running_mean.data, m.running_var.data, \
        m.weight.data, m.bias.data = torch.rand([4, m.num_features])

    image = torch.rand([1, 3, 5, 5])

    # test with the training-time structure
    print(model(image), '\n')

    # call RepConv's class method to merge the branches
    RepConv.merge(model)
    print(model, '\n')

    # test with the inference-time structure
    print(model(image), '\n')

Output before merging the branches:

2.500 ms
tensor([[ 9.5302,  9.0414,  9.1825, 10.1263,  7.4633],
        [10.0533, 11.0839,  9.5532, 10.5910,  7.2358],
        [ 9.1334, 11.0128,  9.5313, 11.7521,  7.4852],
        [ 9.9997, 11.0532, 11.3155,  9.9294,  7.6453],
        [ 9.4239,  9.9330, 10.2139,  9.2274,  7.7268]], grad_fn=<SumBackward1>) 

Output after merging the branches:

0.303 ms
tensor([[ 9.5302,  9.0414,  9.1825, 10.1263,  7.4633],
        [10.0533, 11.0839,  9.5532, 10.5910,  7.2358],
        [ 9.1334, 11.0128,  9.5313, 11.7521,  7.4852],
        [ 9.9997, 11.0532, 11.3155,  9.9294,  7.6453],
        [ 9.4239,  9.9330, 10.2139,  9.2274,  7.7268]], grad_fn=<SumBackward1>) 

Model before merging the branches:

Random_Model(
  (model): Sequential(
    (0): RepConv(
      (conv_main): Conv(
        (conv): Sequential(
          (conv): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (act): Identity()
      )
      (conv_1x1): Conv(
        (conv): Sequential(
          (conv): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (act): Identity()
      )
      (act): SiLU()
    )
    (1): RepConv(
      (conv_main): Conv(
        (conv): Sequential(
          (conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (act): Identity()
      )
      (identity): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (2): RepConv(
      (conv_main): Conv(
        (conv): Sequential(
          (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
          (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (act): Identity()
      )
      (conv_1x1): Conv(
        (conv): Sequential(
          (conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), groups=2, bias=False)
          (bn): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (act): Identity()
      )
      (identity): BatchNorm(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
  )
)  

Model after merging the branches:

Random_Model(
  (model): Sequential(
    (0): RepConv(
      (conv_main): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_1x1): None
      (act): SiLU()
    )
    (1): RepConv(
      (conv_main): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
      (identity): None
      (act): SiLU()
    )
    (2): RepConv(
      (conv_main): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
      (conv_1x1): None
      (identity): None
      (act): SiLU()
    )
  )
)


Reprinted from blog.csdn.net/qq_55745968/article/details/125887670