ICCV 2021 | 性能炸裂的通道剪枝算法ResRep(Keras复现)

清华大学&旷世科技

Lossless CNN Channel Pruning via Decoupling Remembering and Forgetting

paper:https://arxiv.org/pdf/2007.03260v3.pdf

code:https://hub.fastgit.org/DingXiaoH/ResRep

摘要 

        提出了一种新的无损通道修剪(又称滤波器修剪)方法ResRep,其目的是通过减小卷积层的宽度(输出通道数)来精简卷积神经网络(CNN)。受关于记忆和遗忘依赖性的神经生物学研究的启发,论文建议将CNN参数化为记忆部分和获取部分,前者学习保持性能,后者学习效率。通过在前者上使用规则SGD训练参数化模型,而在后者上使用带有惩罚梯度的新更新规则,我们实现了结构化稀疏性,能够等效地将重新参数化的模型转换为具有更窄层的原始架构。这种方法不同于传统的基于学习的剪枝范式,传统的剪枝范式对参数施加惩罚以产生结构化稀疏性,这可能会抑制记忆所必需的参数。该方法将ImageNet上具有76.15%准确度的标准ResNet-50精简为仅具有45%浮点且无准确度Hydrop的较窄的ResNet-50,这是第一个以如此高的压缩比实现无损修剪的方法。

论文背景

  大多数静态通道修剪方法可分为两类。

(1)修剪-微调方法:通过一些测量从训练有素的模型中识别和修剪不重要的通道,这可能会导致显著的精度下降,然后进行微调以恢复性能。一些方法测量特征重要性并逐步修剪,这可以看作是重复的修剪-微调迭代。然而,一个主要缺点是修剪后的模型很难微调,并且最终的精度无法保证。这些方法剪后的模型很容易陷入糟糕的局部极小值,有时甚至无法与从头开始训练的相同结构的对应模型达到类似的精度水平。于是ResRep诞生了,其消除了微调的需要。像ResRep中的剪枝操作是训练参数的数学等价转换,它不会导致性能下降,因此不需要微调。

(2)基于学习的修剪方法:利用定制的学习过程来减少修剪造成的精度损失。除了以上提到的基于惩罚的范式之外,要将一些惩罚归零在这些通道中,其他一些方法通过元学习、对抗学习等进行修剪。与这些复杂的方法相比,ResRep可以易于端到端部署和训练。

论文主要思想

        ResRep利用卷积之间具有线性变化的原理,通过在原始卷积后面加上一个1x1的卷积,将原CNN等价拆分成负责“记忆”(初始化1x1卷积参数为单位矩阵,保持原始精度不变)的部分和负责“遗忘”(去掉某些通道)的部分,前者进行“记忆训练”(不改变原目标函数、不改变训练超参、不改变更新规则),后者进行“遗忘训练”(一种基于SGD的魔改更新规则,即Res),取得更好的效果(更高压缩率、更少精度损失)。然后,将“记忆”和“遗忘”部分等价合并(原始卷积和1x1的卷积线性组合成一个卷积,并且删除通道)成一个更小的模型。

keras实现 

以下是根据论文和pytorch源码实现的keras版本(支持Tensorflow1.x)。

注:安装kerassurgeon模块 pip install kerassurgeon

keras实现的测试代码链接

compactor实现模块:

class CompactorLayer(Layer):
    def __init__(self, filters,
                 rank=2,
                 kernel_size=1,
                 strides=1,
                 padding='same',
                 data_format=None,
                 dilation_rate=1,
                 activation=None,
                 use_bias=False,
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(CompactorLayer, self).__init__(**kwargs)
        self.rank = rank
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, 'dilation_rate')
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(ndim=self.rank + 2)
        self.filters = filters
        self.kernel = None
        self.mask = None
        self.bias = None

    def build(self, input_shape):
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=initializers.constant(np.eye(self.filters)),
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      )
        self.mask = self.add_weight(shape=(1, 1, 1, self.filters),
                                    initializer=tf.ones_initializer(),
                                    name='mask',
                                    trainable=False
                                    )
        self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs, **kwargs):
        temp_kernel = self.kernel * self.mask
        if self.rank == 2:
            outputs = K.conv2d(
                inputs,
                temp_kernel,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_last':
            space = input_shape[1:-1]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding=self.padding,
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            return (input_shape[0],) + tuple(new_space) + (self.filters,)
        if self.data_format == 'channels_first':
            space = input_shape[2:]
            new_space = []
            for i in range(len(space)):
                new_dim = conv_utils.conv_output_length(
                    space[i],
                    self.kernel_size[i],
                    padding=self.padding,
                    stride=self.strides[i],
                    dilation=self.dilation_rate[i])
                new_space.append(new_dim)
            return (input_shape[0], self.filters) + tuple(new_space)

Lasso实现模块:

class Lasso(keras.regularizers.Regularizer):
    def __init__(self, l1=0.):
        self.l1 = K.cast_to_floatx(l1)

    def __call__(self, x):
        regularization = 0.
        regularization += K.sum(self.l1 * x / K.sqrt(K.sum(K.square(x), axis=[0, 1, 2], keepdims=True)))

        return regularization

    def get_config(self):
        return {
            'l1': float(self.l1),
        }


def lasso(learn=0.01):
    return Lasso(l1=learn)

等价合并部分:


def combine_conv_compact(model, conv_name, compact_name):
    conv_kxk_weights = model.get_layer(conv_name).get_weights()
    if len(conv_kxk_weights) > 1:
        conv_kxk_weights, conv_kxk_bias = conv_kxk_weights
    else:
        conv_kxk_weights = conv_kxk_weights[0]
        conv_kxk_bias = np.zeros((conv_kxk_weights.shape[-1],))

    compact_kxk_weights = model.get_layer(compact_name).get_weights()
    compact_kxk_mask = compact_kxk_weights[1]
    compact_kxk_weights = compact_kxk_weights[0]
    conv_kxk_weights = conv_kxk_weights
    conv_kxk_bias = np.reshape(conv_kxk_bias, (1, 1, 1, conv_kxk_bias.shape[0]))
    with tf.Session() as sess:
        conv_1x1 = tf.convert_to_tensor(compact_kxk_weights)
        conv_kxk_w = tf.convert_to_tensor(conv_kxk_weights)
        conv_kxk_b = tf.convert_to_tensor(conv_kxk_bias)
        weight = K.conv2d(conv_kxk_w, conv_1x1, padding='same', data_format='channels_last').eval()
        bias = K.conv2d(conv_kxk_b, conv_1x1, padding='same', data_format='channels_last').eval()

    bias = np.sum(bias, axis=(0, 1, 2))
    model.get_layer(conv_name).set_weights([weight, bias])
    model = delete_layer(model, model.get_layer(compact_name), copy=True)
    del_channel = np.argwhere(compact_kxk_mask.reshape(-1) == 0)
    layer = model.get_layer(conv_name)
    model = delete_channels(model, layer, del_channel, copy=True)
    return model

声明:本内容来源网络,版权属于原作者,图片来源原论文。如有侵权,联系删除。

创作不易,欢迎大家点赞评论收藏关注!(想看更多最新的模型压缩文献欢迎关注浏览我的博客)

猜你喜欢

转载自blog.csdn.net/u011447962/article/details/120316394