深度学习(28)——YOLO系列(7)

深度学习(28)——YOLO系列(7)

咱就是说,需要源码请造访:Jane的GitHub在这里
上午没写完的,下午继续,是一个小尾巴。其实上午把训练的关键部分和数据的关键部分都写完了,现在就是写一下推理部分
在推理过程为了提高效率,速度更快:

detect 全过程

1.1 attempt_load(weights)

  • weights是加载的yolov7之前训练好的权重
  • 刚开始load以后还有BN,没有合并的
    在这里插入图片描述
  • 关键在下面的fuse()

1.2 model.fuse()

在这里插入图片描述

# 很隐蔽,刚开始我没想到接口是在这里的
    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        print('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, RepConv):
                #print(f" fuse_repvgg_block")
                m.fuse_repvgg_block()
            elif isinstance(m, RepConv_OREPA):
                #print(f" switch_to_deploy")
                m.switch_to_deploy()
            elif type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
            elif isinstance(m, (IDetect, IAuxDetect)):
                m.fuse()
                m.forward = m.fuseforward
        self.info()
        return self

当遇到conv后面一定是有BN的,所以
在这里插入图片描述

1.3 fuse_conv_and_bn(conv,bn)

  • 先定义一个新的conv【和原来传入的是一样的inputsize,outputsize和kernel】
    在这里插入图片描述
  • 先得到w_conv: w_conv = conv.weight.clone().view(conv.out_channels, -1)
  • 得到w_bn: w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))bn.weight 就是以下公式中的gamma,sigma平方是方差bn.running_var在这里插入图片描述
  • 得到w_fuse: fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  • 得到b_conv,因为在学习过程中bias我们都设置为0,所以: b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  • 得到b_bn :b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))【bn.bias是上面公式中的β,μ为均值bn.running_mean】
  • 计算b_fusefusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters bn.weight 对应论文中的gamma   bn.bias对应论文中的beta bn.running_mean则是对于当前batch size的数据所统计出来的平均值 bn.running_var是对于当前batch size的数据所统计出来的方差
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv

1.4 Repvgg_block

把Repvgg中的卷积和BN合在一起

  • 原来的block↓
    在这里插入图片描述
  • 融合rbr_dense后:
    在这里插入图片描述
  • 融合rbr_1*1后:
    在这里插入图片描述

1.5 将1* 1卷积padding成3* 3

在这里插入图片描述
padding后
在这里插入图片描述
所有的都改变以后:model长这样——>
在这里插入图片描述
在这里插入图片描述
OK,这次真没啦,886~~~~

猜你喜欢

转载自blog.csdn.net/qq_43368987/article/details/131703296