Merging BN Layers into Conv or FC Layers: Principles and Code Implementation

1. Why merge the BN layer:

A BN (batch-norm) layer is a standard technique in deep learning for speeding up training. It is usually placed after a convolutional (conv) layer or a fully-connected (FC) layer, where it normalizes the activations and accelerates convergence. While BN is useful during training, at inference time it adds extra layers to the network, which slows down the forward pass and consumes additional memory or GPU memory. It is therefore natural to want to fold the BN layer into the adjacent conv or FC layer, and that is exactly the work this article describes.


2. The math behind merging BN:

A BN layer typically sits in a neural network at the position shown below:

[Figure: typical position of a BN layer, directly after a conv (or FC) layer]

As the figure indicates, the BN layer usually comes after the conv (or FC) layer, although in some networks it comes before it. We consider the two cases separately.


2.1 The case where BN follows the conv layer

The principle of the merge is illustrated by the following two figures:

[Figure: how a BN layer transforms its input]

The first figure shows the result of applying the BN operation to an input X.

[Figure: a conv layer followed by BN, and the equivalent merged conv layer with updated w and b]

The second figure shows, in its first part, the output of a conv layer followed by a BN layer and, in its second part, how the conv layer's weight w and bias b change once the BN layer is merged into it.
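In formulas (a standard reconstruction of the derivation the figures carry): for one output channel, let the conv compute $x = w \cdot z + b$. In inference mode the BN layer computes

$$\hat{x} = \gamma\,\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,$$

where $\mu$ and $\sigma^2$ are the moving mean and variance, $\epsilon$ is a small constant, and $\gamma$, $\beta$ are the learned scale and shift. Substituting the conv output gives

$$\hat{x} = \underbrace{\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\,w}_{w'} \cdot z \;+\; \underbrace{\frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2+\epsilon}} + \beta}_{b'},$$

so replacing $(w, b)$ with $(w', b')$ makes the conv layer alone reproduce conv + BN, and the BN layer can be deleted.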


2.2 The case where BN comes first and the conv layer follows

In this case an FC layer can absorb the BN in much the same way as in 2.1. For a conv layer that follows the BN, however, the merge fails whenever the conv uses padding: the padded border values are zeros that never pass through the BN, whereas a genuine BN output would map those zeros to $\beta - \gamma\mu/\sqrt{\sigma^2+\epsilon}$, a value that ordinary zero-padding cannot represent.
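For completeness, here is a sketch of the BN-before-FC fold (standard algebra, not taken from the original figures): if $\hat{x} = \gamma\,(x - \mu)/\sqrt{\sigma^2+\epsilon} + \beta$ feeds an FC layer $y = W\hat{x} + b$, then

$$y = W\,\mathrm{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\right) x \;+\; \left(b + W\!\left(\beta - \frac{\gamma\,\mu}{\sqrt{\sigma^2+\epsilon}}\right)\right),$$

i.e. the scale folds into the input-side columns of $W$ and the shift folds into the bias. For a padded conv no such pair $(W', b')$ exists, which is exactly the obstruction described above.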


3. Code for merging conv and BN
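Before the two framework-specific scripts, here is a minimal NumPy sketch of the 2.1 formulas (an illustration I add here; all names and shapes are made up) that folds an inference-mode BN into a preceding FC layer and checks that the outputs match:

import numpy as np

# Fold inference-mode BN into a preceding FC layer and verify equivalence.
rng = np.random.RandomState(0)
n_in, n_out = 8, 4
W = rng.randn(n_out, n_in)        # FC weight
b = rng.randn(n_out)              # FC bias
gamma = rng.rand(n_out) + 0.5     # BN scale
beta = rng.randn(n_out)           # BN shift
mu = rng.randn(n_out)             # BN moving mean
var = rng.rand(n_out) + 0.1       # BN moving variance
eps = 1e-5

alpha = gamma / np.sqrt(var + eps)       # per-channel scale
W_merged = W * alpha[:, None]            # w' = alpha * w
b_merged = alpha * (b - mu) + beta       # b' = alpha * (b - mu) + beta

x = rng.randn(3, n_in)
y_ref = gamma * ((x.dot(W.T) + b) - mu) / np.sqrt(var + eps) + beta
y_merged = x.dot(W_merged.T) + b_merged
print(np.abs(y_ref - y_merged).max())    # on the order of 1e-15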

3.1 Caffe version (this script is one I obtained from the web; it will be taken down on request if it infringes)


  
  
#!/usr/bin/env python
import _init_paths  # path helper from the surrounding project
import numpy as np
import sys
import os
import os.path as osp
import google.protobuf as pb
from argparse import ArgumentParser
import caffe


def load_and_fill_biases(src_model, src_weights, dst_model, dst_weights):
    with open(src_model) as f:
        model = caffe.proto.caffe_pb2.NetParameter()
        pb.text_format.Merge(f.read(), model)

    for i, layer in enumerate(model.layer):
        if layer.type == 'Convolution':  # or layer.type == 'Scale':
            # Add bias term if needed
            if layer.convolution_param.bias_term == False:
                layer.convolution_param.bias_term = True
                layer.convolution_param.bias_filler.type = 'constant'
                layer.convolution_param.bias_filler.value = 0.0

    with open(dst_model, 'w') as f:
        f.write(pb.text_format.MessageToString(model))

    caffe.set_mode_cpu()
    net_src = caffe.Net(src_model, src_weights, caffe.TEST)
    net_dst = caffe.Net(dst_model, caffe.TEST)
    for key in net_src.params.keys():
        for i in range(len(net_src.params[key])):
            net_dst.params[key][i].data[:] = net_src.params[key][i].data[:]

    if dst_weights is not None:
        # Store params
        pass

    return net_dst


def merge_conv_and_bn(net, i_conv, i_bn, i_scale):
    # This is based on Kyeheyon's work
    assert(i_conv != None)
    assert(i_bn != None)

    def copy_double(data):
        return np.array(data, copy=True, dtype=np.double)

    key_conv = net._layer_names[i_conv]
    key_bn = net._layer_names[i_bn]
    key_scale = net._layer_names[i_scale] if i_scale else None

    # Copy the BN statistics; Caffe's BatchNorm stores the running mean/var
    # scaled by a moving-average factor kept in blob [2]
    bn_mean = copy_double(net.params[key_bn][0].data)
    bn_variance = copy_double(net.params[key_bn][1].data)
    num_bn_samples = copy_double(net.params[key_bn][2].data)

    # ... and invalidate the BN layer (turn it into an identity)
    net.params[key_bn][0].data[:] = 0
    net.params[key_bn][1].data[:] = 1
    net.params[key_bn][2].data[:] = 1
    if num_bn_samples[0] == 0:
        num_bn_samples[0] = 1

    if net.params.has_key(key_scale):
        print 'Combine {:s} + {:s} + {:s}'.format(key_conv, key_bn, key_scale)
        scale_weight = copy_double(net.params[key_scale][0].data)
        scale_bias = copy_double(net.params[key_scale][1].data)
        net.params[key_scale][0].data[:] = 1
        net.params[key_scale][1].data[:] = 0
    else:
        print 'Combine {:s} + {:s}'.format(key_conv, key_bn)
        scale_weight = 1
        scale_bias = 0

    # Fold BN (and Scale) into the conv:
    # w' = alpha * w, b' = alpha * b + (scale_bias - alpha * mean)
    weight = copy_double(net.params[key_conv][0].data)
    bias = copy_double(net.params[key_conv][1].data)
    alpha = scale_weight / np.sqrt(bn_variance / num_bn_samples[0] + np.finfo(np.double).eps)
    net.params[key_conv][1].data[:] = bias * alpha + (scale_bias - (bn_mean / num_bn_samples[0]) * alpha)
    for i in range(len(alpha)):
        net.params[key_conv][0].data[i] = weight[i] * alpha[i]


def merge_batchnorms_in_net(net):
    # for each BN
    for i, layer in enumerate(net.layers):
        if layer.type != 'BatchNorm':
            continue

        l_name = net._layer_names[i]

        l_bottom = net.bottom_names[l_name]
        assert(len(l_bottom) == 1)
        l_bottom = l_bottom[0]
        l_top = net.top_names[l_name]
        assert(len(l_top) == 1)
        l_top = l_top[0]

        can_be_absorbed = True

        # Search all (bottom) layers for the producing Convolution/InnerProduct
        for j in xrange(i - 1, -1, -1):
            tops_of_j = net.top_names[net._layer_names[j]]
            if l_bottom in tops_of_j:
                if net.layers[j].type not in ['Convolution', 'InnerProduct']:
                    can_be_absorbed = False
                else:
                    # There must be only one layer
                    conv_ind = j
                    break
        if not can_be_absorbed:
            continue

        # find the following Scale
        scale_ind = None
        for j in xrange(i + 1, len(net.layers)):
            bottoms_of_j = net.bottom_names[net._layer_names[j]]
            if l_top in bottoms_of_j:
                if scale_ind:
                    # Followed by two or more layers
                    scale_ind = None
                    break
                if net.layers[j].type in ['Scale']:
                    scale_ind = j
                    top_of_j = net.top_names[net._layer_names[j]][0]
                    if top_of_j == bottoms_of_j[0]:
                        # On-the-fly (in-place) => can be merged
                        break
                else:
                    # Followed by a layer which is not 'Scale'
                    scale_ind = None
                    break
        merge_conv_and_bn(net, conv_ind, i, scale_ind)

    return net


def process_model(net, src_model, dst_model, func_loop, func_finally):
    with open(src_model) as f:
        model = caffe.proto.caffe_pb2.NetParameter()
        pb.text_format.Merge(f.read(), model)

    for i, layer in enumerate(model.layer):
        map(lambda x: x(layer, net, model, i), func_loop)
    map(lambda x: x(net, model), func_finally)

    with open(dst_model, 'w') as f:
        f.write(pb.text_format.MessageToString(model))


# Functions to remove (redundant) BN and Scale layers
to_delete_empty = []


def pick_empty_layers(layer, net, model, i):
    if layer.type not in ['BatchNorm', 'Scale']:
        return

    bottom = layer.bottom[0]
    top = layer.top[0]
    if (bottom != top):
        # Not supported yet
        return

    if layer.type == 'BatchNorm':
        zero_mean = np.all(net.params[layer.name][0].data == 0)
        one_var = np.all(net.params[layer.name][1].data == 1)
        length_is_1 = (net.params[layer.name][2].data == 1)
        if zero_mean and one_var and length_is_1:
            print 'Delete layer: {}'.format(layer.name)
            to_delete_empty.append(layer)

    if layer.type == 'Scale':
        no_scaling = np.all(net.params[layer.name][0].data == 1)
        zero_bias = np.all(net.params[layer.name][1].data == 0)
        if no_scaling and zero_bias:
            print 'Delete layer: {}'.format(layer.name)
            to_delete_empty.append(layer)


def remove_empty_layers(net, model):
    map(model.layer.remove, to_delete_empty)


# A function to add 'engine: CAFFE' param into 1x1 convolutions
def set_engine_caffe(layer, net, model, i):
    if layer.type == 'Convolution':
        if layer.convolution_param.kernel_size == 1 \
                or (layer.convolution_param.kernel_h == layer.convolution_param.kernel_w == 1):
            layer.convolution_param.engine = dict(layer.convolution_param.Engine.items())['CAFFE']


def main(args):
    # Set default output file names
    if args.output_model is None:
        file_name = osp.splitext(args.model)[0]
        args.output_model = file_name + '_inference.prototxt'
    if args.output_weights is None:
        file_name = osp.splitext(args.weights)[0]
        args.output_weights = file_name + '_inference.caffemodel'

    net = load_and_fill_biases(args.model, args.weights, args.model + '.temp.pt', None)
    net = merge_batchnorms_in_net(net)
    process_model(net, args.model + '.temp.pt', args.output_model,
                  [pick_empty_layers, set_engine_caffe],
                  [remove_empty_layers])

    # Store params
    net.save(args.output_weights)


if __name__ == '__main__':
    parser = ArgumentParser(
        description="Generate Batch Normalized model for inference")
    parser.add_argument('model', help="The net definition prototxt")
    parser.add_argument('weights', help="The weights caffemodel")
    parser.add_argument('--output_model')
    parser.add_argument('--output_weights')
    args = parser.parse_args()
    main(args)
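The script exposes a small command-line interface; assuming it is saved as merge_bn_caffe.py (the file name here is illustrative), a typical invocation is:

python merge_bn_caffe.py deploy.prototxt weights.caffemodel

which, when --output_model and --output_weights are not given, writes deploy_inference.prototxt and weights_inference.caffemodel next to the inputs.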

3.2 MXNet version (my own implementation; there is a known issue when conv_no_bias=True)


  
  
import sys, argparse
import find_mxnet, find_caffe  # path helpers from the author's project
import mxnet as mx
import caffe
import pdb
import json
import numpy as np
import copy


def merge_bn_into_conv_or_fc(json_str, net_param):
    json_obj, nodes, names, old_num_to_name, inputs = load_json(json_str)
    name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])
    bn_name_list = []    # stores the bn names
    conv_name_list = []  # stores the matching conv (or fc) names
    for i in range(len(json_obj['nodes'])):
        # search for each BatchNorm whose direct input is a conv (or fc)
        if json_obj['nodes'][i]['op'] == "BatchNorm":
            may_conv_index = json_obj['nodes'][i]['inputs'][0][0]
            if json_obj['nodes'][may_conv_index]['op'] in ["Convolution", "FullyConnected"]:
                bn_name_list.append(json_obj['nodes'][i]['name'])
                conv_name_list.append(json_obj['nodes'][may_conv_index]['name'])
    if len(bn_name_list) != len(conv_name_list) or len(bn_name_list) <= 0:
        print "error, len(bn_name_list) should be equal to len(conv_name_list)"
        exit()
    for i in range(len(bn_name_list)):
        print i
        json_obj, nodes, names, old_num_to_name, inputs = load_json(json_str)
        name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])
        # bn name, eps and fix_gamma; note the 'param' values in the graph
        # json are strings, so bool("False") would be True -- compare strings
        bn_index = name_to_num[bn_name_list[i]]
        bn_name = json_obj['nodes'][bn_index]['name']
        bn_eps = float(json_obj['nodes'][bn_index]['param']['eps'])
        bn_fixgamma = (json_obj['nodes'][bn_index]['param']['fix_gamma'] == 'True')
        # conv name and no_bias flag
        conv_index = name_to_num[conv_name_list[i]]
        conv_name = json_obj['nodes'][conv_index]['name']
        conv_no_bias = (json_obj['nodes'][conv_index]['param']['no_bias'] == 'True')
        # first merge the weights, then rewrite the json graph
        net_param = copy.deepcopy(merge_bn_conv_after(
            net_param=net_param, conv_name=conv_name, bn_name=bn_name,
            fix_gamma=bn_fixgamma, no_bias=conv_no_bias, eps=bn_eps))
        json_str = copy.deepcopy(merge_bn_conv_after_bn_json(
            json_str=json_str, conv_name=conv_name, bn_name=bn_name,
            fix_gamma=bn_fixgamma, no_bias=conv_no_bias, eps=bn_eps))
    return json_str, net_param


def load_json(json_str):
    json_obj = json.loads(json_str)  # dict containing "nodes", "arg_nodes", "heads"
    nodes = json_obj['nodes']        # a list, len = number of nodes
    names = [node['name'] for node in nodes]
    old_num_to_name = dict(enumerate(names))
    name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])
    inputs = [node['inputs'] for node in nodes]
    return json_obj, nodes, names, old_num_to_name, inputs


def merge_bn_conv_after_bn_json(json_str, conv_name, bn_name, fix_gamma=False, no_bias=False, eps=0.001):
    json_obj, nodes, names, old_num_to_name, inputs = load_json(json_str)
    name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])
    # look up the conv and bn node indices
    conv_index = name_to_num[conv_name]
    bn_index = name_to_num[bn_name]
    for i in range(len(json_obj['nodes'])):
        if len(json_obj['nodes'][i]['inputs']) <= 0:
            continue  # skip nodes with inputs == []
        # redirect every consumer of the bn node to the conv node
        input_list = json_obj['nodes'][i]['inputs']
        for j in range(len(input_list)):
            if input_list[j][0] == bn_index:
                input_list[j][0] = conv_index
        json_obj['nodes'][i]['inputs'] = input_list
        # demote the bn node itself from an op to a plain param node
        if json_obj['nodes'][i]['name'] == bn_name:
            json_obj['nodes'][i] = copy.deepcopy(json_obj['nodes'][i - 1])
            json_obj['nodes'][i]['name'] = bn_name
    if no_bias == True:
        # the conv had no bias: reuse the bn beta node (assumed to sit right
        # before the bn node) as the new conv bias -- this is the branch with
        # the known issue mentioned above
        json_obj['nodes'][int(bn_index) - 1]['name'] = conv_name + '_bias'
        json_obj['nodes'][conv_index]['param']['no_bias'] = "False"
        list_add = []
        list_add.append(int(bn_index) - 1)
        list_add.append(0)
        json_obj['nodes'][conv_index]['inputs'].append(list_add)
    return json.dumps(json_obj, indent=4)


# merge a conv (or fc) with the bn that follows it -- the weight side
def merge_bn_conv_after(net_param, conv_name, bn_name, fix_gamma=False, no_bias=False, eps=0.001):
    gamma = net_param['arg:' + bn_name + '_gamma'].asnumpy()  # bn scale gamma
    if fix_gamma == True:  # fix_gamma=True means gamma is fixed to all ones
        gamma *= 0
        gamma += 1
    beta = net_param['arg:' + bn_name + '_beta'].asnumpy()             # bn shift beta
    mov_mean = net_param['aux:' + bn_name + '_moving_mean'].asnumpy()  # bn moving mean
    mov_var = net_param['aux:' + bn_name + '_moving_var'].asnumpy()    # bn moving var
    mov_std = np.sqrt(mov_var + eps)  # std from var
    # conv weight and bias before the merge
    part_0_conv_weight = net_param['arg:' + conv_name + '_weight'].asnumpy()
    output_channel = part_0_conv_weight.shape[0]
    if no_bias == True:
        # no existing bias: start from zeros (the problematic branch)
        part_0_conv_bias = np.zeros((output_channel,), dtype=np.float64)
    else:
        part_0_conv_bias = net_param['arg:' + conv_name + '_bias'].asnumpy()
        for i in range(output_channel):
            part_0_conv_bias[i] *= float(gamma[i] / mov_std[i])  # b *= alpha
    # pdb.set_trace()  # debug breakpoint from the author, disabled
    for i in range(output_channel):
        # plain [i] indexing so this works for 4-d conv weights
        # [out, in, kh, kw] as well as 2-d fc weights [out, in]
        part_0_conv_weight[i] *= float(gamma[i] / mov_std[i])                        # w' = alpha * w
        part_0_conv_bias[i] += beta[i] - float(gamma[i] * mov_mean[i] / mov_std[i])  # b' += beta - alpha * mean
    # write the merged params back
    net_param['arg:' + conv_name + '_weight'] = mx.nd.array(part_0_conv_weight)
    net_param['arg:' + conv_name + '_bias'] = mx.nd.array(part_0_conv_bias)
    return net_param


# input: the trained params and symbol json, e.g.
#   resnet/base-14-9999.params and resnet/base-14.json
input_param = sys.argv[1]       # such as resnet/base-14-9999.params
input_json = file(sys.argv[2])  # such as resnet/base-14.json
net_param = mx.nd.load(input_param)
new_json_str, new_param = merge_bn_into_conv_or_fc(input_json.read(), net_param)
open((sys.argv[2]).replace(".json", "_change.json"), "w").write(new_json_str)
mx.nd.save(input_param.replace(".params", "_change.params"), new_param)
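Usage follows directly from the argv handling at the bottom of the script (the script file name here is illustrative):

python merge_bn_mxnet.py resnet/base-14-9999.params resnet/base-14.json

which writes resnet/base-14_change.json and resnet/base-14-9999_change.params next to the inputs.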


Reposted from blog.csdn.net/m0_37192554/article/details/85049433