MimicNorm: Replacing BN Layers with ~20% Lower Memory Usage | Weight Mean and Last BN Layer Mimic the Dynamic of Batch Normalization

The paper reports roughly a 20% reduction in memory consumption, but there is no comparison of inference speed…
Paper: https://arxiv.org/pdf/2010.09278.pdf
GitHub: https://github.com/Kid-key/MimicNorm


Abstract:

Substantial experiments have validated the success of Batch Normalization (BN) layers in benefiting convergence and generalization. However, BN requires additional memory and floating-point computation. Moreover, BN is inaccurate on micro-batches, because it depends on batch statistics. In this paper, these problems are addressed by simplifying BN regularization while keeping two of BN's fundamental effects: data decorrelation and adaptive learning rate. The authors propose a novel normalization method, named MimicNorm, to improve the convergence and efficiency of network training. MimicNorm consists of only two lightweight operations: a modified weight mean operation (subtracting the mean from the weight parameter tensor) and one BN layer before the loss function (the last BN layer). They leverage Neural Tangent Kernel (NTK) theory to prove that the weight mean operation whitens activations and transitions the network into a chaotic regime, as BN layers do, leading to enhanced convergence. The last BN layer provides an auto-tuned learning rate and also improves accuracy. Experimental results show that MimicNorm achieves similar accuracy on various network structures, including ResNets and lightweight networks such as ShuffleNet, while reducing memory consumption by about 20%.

MimicNorm:

There are two main changes:
1. In the basic block, remove BN and replace it with a weight-mean Conv + a scalar layer + ReLU.
2. Add a BN layer right before the loss function to auto-tune the learning rate.

Let's go straight to the code:

1. Conv2d with weight mean

import torch
import torch.nn as nn

def meanweigh(module):
    # Subtract the per-filter mean from conv weights; small convolutions
    # (fan-in per group <= 50, e.g. depthwise convs) are left unchanged.
    if isinstance(module, nn.Conv2d) and \
            module.in_channels * module.kernel_size[0] * module.kernel_size[1] / module.groups > 50:
        datavalue = module.weight.data
        meanvalue = datavalue.mean(dim=[1, 2, 3], keepdim=True)   # mean over in_channels x kH x kW
        module.weight.data = datavalue - meanvalue
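To see the effect, the function can be applied to every submodule via nn.Module.apply. A minimal sketch (the toy Sequential model below is only for illustration, it is not from the repo):

import torch
import torch.nn as nn

# Hypothetical toy model, just to demonstrate applying the weight mean.
toy = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True), nn.ReLU())
toy.apply(meanweigh)                        # nn.Module.apply visits every submodule
print(toy[0].weight.mean(dim=[1, 2, 3]))    # per-filter means are now numerically ~0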

2. The basic block. After BN is removed, a learnable self.scale handles the scaling instead; in other words, Conv->BN->ReLU->Conv becomes Conv#->ReLU->Conv#->Scale, where Conv# denotes a Conv whose weights have gone through the weight mean operation.

class MyPassLayer(nn.Module):
    """Learnable scalar scale layer used in place of BN."""
    def __init__(self, inival=1.2):
        super(MyPassLayer, self).__init__()
        self.scale = nn.Parameter(torch.ones(1) * inival)

    def forward(self, x):
        return x * self.scale

class BasicBlock(nn.Module):
    """Basic block for ResNet-18 and ResNet-34."""

    # BasicBlock and BottleNeck have different output sizes;
    # the class attribute `expansion` distinguishes them.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        global subnetnum          # module-level block counter (initialization not shown in this snippet)
        subnetnum += 1.

        # residual branch: BN layers removed, plain Conv -> ReLU -> Conv
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=True),
        )
        # learnable scalar replacing BN scaling; its initial value shrinks with block depth
        self.scale = nn.Parameter(torch.ones(1) / subnetnum * 1.2 ** 2)

        # shortcut: identity by default
        self.shortcut = nn.Sequential()

        # when the shortcut output dimension does not match the residual branch,
        # use a 1x1 convolution (followed by a scalar layer instead of BN) to match it
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=True),
                MyPassLayer()  # replaces nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) * self.scale + self.shortcut(x))
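A quick shape check of the modified block, assuming everything above lives in one file and that subnetnum is a module-level counter initialized before any block is built (the initialization is not shown in the snippet):

import torch

subnetnum = 0                                   # assumed initialization; the repo may set this elsewhere
block = BasicBlock(in_channels=64, out_channels=64, stride=1)
block.apply(meanweigh)                          # re-center the conv weights inside the block
x = torch.randn(4, 64, 32, 32)
print(block(x).shape)                           # torch.Size([4, 64, 32, 32])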

3. Add a last BN layer right before the end of the network:

        # in __init__ (CBN is a constructor flag that enables the last BN layer):
        self.CBN = CBN
        if CBN:
            self.lastbn = nn.BatchNorm1d(num_classes, affine=False)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)
        if self.CBN:
            # parameter-free BN normalizes the logits right before the loss
            output = self.lastbn(output)
        return output
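A standalone look at what this last BN does, using dummy logits (the batch size and scale below are made up purely for illustration):

import torch
import torch.nn as nn

num_classes = 100
lastbn = nn.BatchNorm1d(num_classes, affine=False)   # parameter-free, as in the snippet above
logits = torch.randn(32, num_classes) * 5 + 3        # dummy logits with arbitrary scale and shift
normed = lastbn(logits)                              # per-class standardization across the batch
print(normed.mean(dim=0).abs().max())                # ~0: each class's logits are centered
print(normed.std(dim=0).mean())                      # ~1: and rescaled to unit variance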

Experiments:

1. CIFAR-100:
(Results figure from the paper.)
2. ImageNet:
(Results figure from the paper.)


Reposted from blog.csdn.net/weixin_42096202/article/details/109474213