模型优化:BatchNorm合并到卷积中

1.bn合并的必要性:

   bn层即batch-norm层,一般是深度学习中用于加速训练速度和一种方法,一般放置在卷积层(conv层)或者全连接层之后,将数据归一化并加速了训练拟合速度。但是bn层虽然在深度学习模型训练时起到了一定的积极作用,但是在预测时因为凭空多了一些层,影响了整体的计算速度并占用了更多内存或者显存空间。所以我们设想如果能将bn层合并到相邻的卷积层或者全连接层之后就好了,于是就有了这篇文章所提到的工作。

2.bn合并本身的数学原理:

                      bn层一般在神经网络中‘所处的位置如下图所示:

如上图可以看到,bn层的位置一般在conv(or Fc)层的后面,也有一些情况bn在conv(or Fc)层的前面。我们先来两种情况分别来考虑。

2.1 bn层在conv层之后的情形

bn合并的原理,可以由下两张图所示:

               bn层进行数据处理的过程

这张图的表示,将一个数据X,进行bn层的操作和计算得到的结果。

这张图表示,第一部分代表bn层处理之后接着卷基层的操作结果,第二部分表示将bn层合并到卷积层之后,卷积层w和b的变化。

注意点:conv,Bn,Scale,层之间的top和bottom的名字要相同

2.2  bn在前,卷积在后的合并方式

       这种情况下,FC层的合并方式和之前2.1的结果类似,但是bn在前,conv在后的情形,因为conv存在pad的情形,所以无法合并。

3.卷积和bn合并的代码实现

3.1 caffe版本(该版本是我从网络获取的,如侵权删)

[python] view plain copy

  1. #!/usr/bin/env python  
  2. import _init_paths  
  3. import numpy as np  
  4. import sys  
  5. import os  
  6. import os.path as osp  
  7. import google.protobuf as pb  
  8. from argparse import ArgumentParser  
  9. import sys  
  10. import caffe  
  11.   
  12.   
  13. def load_and_fill_biases(src_model, src_weights, dst_model, dst_weights):  
  14.     with open(src_model) as f:  
  15.         model = caffe.proto.caffe_pb2.NetParameter()  
  16.         pb.text_format.Merge(f.read(), model)  
  17.   
  18.     for i, layer in enumerate(model.layer):  
  19.         if layer.type == 'Convolution': # or layer.type == 'Scale':  
  20.             # Add bias layer if needed  
  21.             if layer.convolution_param.bias_term == False:  
  22.                 layer.convolution_param.bias_term = True  
  23.                 layer.convolution_param.bias_filler.type = 'constant'  
  24.                 layer.convolution_param.bias_filler.value = 0.0  
  25.   
  26.     with open(dst_model, 'w') as f:  
  27.         f.write(pb.text_format.MessageToString(model))  
  28.   
  29.     caffe.set_mode_cpu()  
  30.     net_src = caffe.Net(src_model, src_weights, caffe.TEST)  
  31.     net_dst = caffe.Net(dst_model, caffe.TEST)  
  32.     for key in net_src.params.keys():  
  33.         for i in range(len(net_src.params[key])):  
  34.             net_dst.params[key][i].data[:] = net_src.params[key][i].data[:]  
  35.   
  36.     if dst_weights is not None:  
  37.         # Store params  
  38.         pass  
  39.   
  40.     return net_dst  
  41.   
  42.   
  43. def merge_conv_and_bn(net, i_conv, i_bn, i_scale):  
  44.     # This is based on Kyeheyon's work  
  45.     assert(i_conv != None)  
  46.     assert(i_bn != None)  
  47.   
  48.     def copy_double(data):  
  49.         return np.array(data, copy=True, dtype=np.double)  
  50.   
  51.     key_conv = net._layer_names[i_conv]  
  52.     key_bn = net._layer_names[i_bn]  
  53.     key_scale = net._layer_names[i_scale] if i_scale else None  
  54.   
  55.     # Copy  
  56.     bn_mean = copy_double(net.params[key_bn][0].data)  
  57.     bn_variance = copy_double(net.params[key_bn][1].data)  
  58.     num_bn_samples = copy_double(net.params[key_bn][2].data)  
  59.   
  60.     # and Invalidate the BN layer  
  61.     net.params[key_bn][0].data[:] = 0  
  62.     net.params[key_bn][1].data[:] = 1  
  63.     net.params[key_bn][2].data[:] = 1  
  64.     if num_bn_samples[0] == 0:  
  65.         num_bn_samples[0] = 1  
  66.   
  67.     if net.params.has_key(key_scale):  
  68.         print 'Combine {:s} + {:s} + {:s}'.format(key_conv, key_bn, key_scale)  
  69.         scale_weight = copy_double(net.params[key_scale][0].data)  
  70.         scale_bias = copy_double(net.params[key_scale][1].data)  
  71.         net.params[key_scale][0].data[:] = 1  
  72.         net.params[key_scale][1].data[:] = 0  
  73.     else:  
  74.         print 'Combine {:s} + {:s}'.format(key_conv, key_bn)  
  75.         scale_weight = 1  
  76.         scale_bias = 0  
  77.   
  78.     weight = copy_double(net.params[key_conv][0].data)  
  79.     bias = copy_double(net.params[key_conv][1].data)  
  80.     alpha = scale_weight / np.sqrt(bn_variance / num_bn_samples[0] + np.finfo(np.double).eps)  
  81.     net.params[key_conv][1].data[:] = bias * alpha + (scale_bias - (bn_mean / num_bn_samples[0]) * alpha)  
  82.     for i in range(len(alpha)):  
  83.         net.params[key_conv][0].data[i] = weight[i] * alpha[i]  
  84.   
  85. def merge_batchnorms_in_net(net):  
  86.     # for each BN  
  87.     for i, layer in enumerate(net.layers):  
  88.         if layer.type != 'BatchNorm':  
  89.             continue  
  90.   
  91.         l_name = net._layer_names[i]  
  92.   
  93.         l_bottom = net.bottom_names[l_name]  
  94.         assert(len(l_bottom) == 1)  
  95.         l_bottom = l_bottom[0]  
  96.         l_top = net.top_names[l_name]  
  97.         assert(len(l_top) == 1)  
  98.         l_top = l_top[0]  
  99.   
  100.         can_be_absorbed = True  
  101.   
  102.         # Search all (bottom) layers  
  103.         for j in xrange(i - 1, -1, -1):  
  104.             tops_of_j = net.top_names[net._layer_names[j]]  
  105.             if l_bottom in tops_of_j:  
  106.                 if net.layers[j].type not in ['Convolution', 'InnerProduct']:  
  107.                     can_be_absorbed = False  
  108.                 else:  
  109.                     # There must be only one layer  
  110.                     conv_ind = j  
  111.                     break  
  112.   
  113.         if not can_be_absorbed:  
  114.             continue  
  115.   
  116.         # find the following Scale  
  117.         scale_ind = None  
  118.         for j in xrange(i + 1, len(net.layers)):  
  119.             bottoms_of_j = net.bottom_names[net._layer_names[j]]  
  120.             if l_top in bottoms_of_j:  
  121.                 if scale_ind:  
  122.                     # Followed by two or more layers  
  123.                     scale_ind = None  
  124.                     break  
  125.   
  126.                 if net.layers[j].type in ['Scale']:  
  127.                     scale_ind = j  
  128.   
  129.                     top_of_j = net.top_names[net._layer_names[j]][0]  
  130.                     if top_of_j == bottoms_of_j[0]:  
  131.                         # On-the-fly => Can be merged  
  132.                         break  
  133.   
  134.                 else:  
  135.                     # Followed by a layer which is not 'Scale'  
  136.                     scale_ind = None  
  137.                     break  
  138.   
  139.   
  140.         merge_conv_and_bn(net, conv_ind, i, scale_ind)  
  141.   
  142.     return net  
  143.   
  144.   
  145. def process_model(net, src_model, dst_model, func_loop, func_finally):  
  146.     with open(src_model) as f:  
  147.         model = caffe.proto.caffe_pb2.NetParameter()  
  148.         pb.text_format.Merge(f.read(), model)  
  149.   
  150.   
  151.     for i, layer in enumerate(model.layer):  
  152.         map(lambda x: x(layer, net, model, i), func_loop)  
  153.   
  154.     map(lambda x: x(net, model), func_finally)  
  155.   
  156.     with open(dst_model, 'w') as f:  
  157.         f.write(pb.text_format.MessageToString(model))  
  158.   
  159.   
  160. # Functions to remove (redundant) BN and Scale layers  
  161. to_delete_empty = []  
  162. def pick_empty_layers(layer, net, model, i):  
  163.     if layer.type not in ['BatchNorm', 'Scale']:  
  164.         return  
  165.   
  166.     bottom = layer.bottom[0]  
  167.     top = layer.top[0]  
  168.   
  169.     if (bottom != top):  
  170.         # Not supperted yet  
  171.         return  
  172.   
  173.     if layer.type == 'BatchNorm':  
  174.         zero_mean = np.all(net.params[layer.name][0].data == 0)  
  175.         one_var = np.all(net.params[layer.name][1].data == 1)  
  176.         #length_is_1 = (net.params['conv1_1/bn'][2].data == 1) or (net.params[layer.name][2].data == 0)  
  177.         length_is_1 =  (net.params[layer.name][2].data == 1)  
  178.   
  179.         if zero_mean and one_var and length_is_1:  
  180.             print 'Delete layer: {}'.format(layer.name)  
  181.             to_delete_empty.append(layer)  
  182.   
  183.     if layer.type == 'Scale':  
  184.         no_scaling = np.all(net.params[layer.name][0].data == 1)  
  185.         zero_bias = np.all(net.params[layer.name][1].data == 0)  
  186.   
  187.         if no_scaling and zero_bias:  
  188.             print 'Delete layer: {}'.format(layer.name)  
  189.             to_delete_empty.append(layer)  
  190.   
  191. def remove_empty_layers(net, model):  
  192.     map(model.layer.remove, to_delete_empty)  
  193.   
  194.   
  195. # A function to add 'engine: CAFFE' param into 1x1 convolutions  
  196. def set_engine_caffe(layer, net, model, i):  
  197.     if layer.type == 'Convolution':  
  198.         if layer.convolution_param.kernel_size == 1\  
  199.             or (layer.convolution_param.kernel_h == layer.convolution_param.kernel_w == 1):  
  200.             layer.convolution_param.engine = dict(layer.convolution_param.Engine.items())['CAFFE']  
  201.   
  202.   
  203. def main(args):  
  204.     # Set default output file names  
  205.     if args.output_model is None:  
  206.         file_name = osp.splitext(args.model)[0]  
  207.         args.output_model = file_name + '_inference.prototxt'  
  208.     if args.output_weights is None:  
  209.         file_name = osp.splitext(args.weights)[0]  
  210.         args.output_weights = file_name + '_inference.caffemodel'  
  211.   
  212.     net = load_and_fill_biases(args.model, args.weights, args.model + '.temp.pt', None)  
  213.   
  214.     net = merge_batchnorms_in_net(net)  
  215.   
  216.     process_model(net, args.model + '.temp.pt', args.output_model,  
  217.                   [pick_empty_layers, set_engine_caffe],  
  218.                   [remove_empty_layers])  
  219.   
  220.     # Store params  
  221.     net.save(args.output_weights)  
  222.   
  223.   
  224. if __name__ == '__main__':  
  225.     parser = ArgumentParser(  
  226.             description="Generate Batch Normalized model for inference")  
  227.     parser.add_argument('model', help="The net definition prototxt")  
  228.     parser.add_argument('weights', help="The weights caffemodel")  
  229.     parser.add_argument('--output_model')  
  230.     parser.add_argument('--output_weights')  
  231.     args = parser.parse_args()  
  232.     main(args)  

猜你喜欢

转载自blog.csdn.net/a8039974/article/details/83686633