如何正确的理解RPN网络的训练

刚开始学Faster RCNN时,遇到这么一个困惑不知其他人有没有:
RPN网络在程序中的训练是如何进行的?它都训练了网络中的哪些部分?其实这些我们如果不看源码都很难真正理解!
我们以Faster-RCNN_TF的源码为例,以下代码取自./lib/networks/VGGnet_train.py

 #========= RPN ============
 #以下代码的先后顺序我调整了一下,便于理解
 (self.feed('conv5_3')
     .conv(3,3,512,1,1,name='rpn_conv/3x3')
     .conv(1,1,len(anchor_scales)*3*2 ,1 , 1, padding='VALID', relu = False, name='rpn_cls_score'))

 (self.feed('rpn_conv/3x3')
     .conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='VALID', relu = False, name='rpn_bbox_pred'))

 #重点:
 # Loss of rpn_cls & rpn_boxes
 #anchor_target_layer的返回值'rpn-data'
 #分别是:rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights
 #这四个值可以看作是标注
 #使用rpn_labels 和 'rpn_cls_score_reshape'计算损失
 #使用rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights和'rpn_bbox_pred'计算损失
 #所以产生'rpn_cls_score_reshape'和'rpn_bbox_pred'的网络就是RPN网络需要训练的模型
 #也就是说上面那两段代码所建立的网络,当然了还包括更以前的VGG16的各个可以被训练的层!
 #而下面这段代码,anchor_target_layer()的作用就是产生训练目标,在测试时这里没用!
 (self.feed('rpn_cls_score','gt_boxes','im_info','data')
     .anchor_target_layer(_feat_stride, anchor_scales, name = 'rpn-data' ))

损失函数的计算:

# RPN
# classification loss
rpn_cls_score = tf.reshape(self.net.get_output('rpn_cls_score_reshape'),[-1,2])
rpn_label = tf.reshape(self.net.get_output('rpn-data')[0],[-1])
rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score,tf.where(tf.not_equal(rpn_label,-1))),[-1,2])
rpn_label = tf.reshape(tf.gather(rpn_label,tf.where(tf.not_equal(rpn_label,-1))),[-1])
rpn_cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))

# bounding box regression L1 loss
rpn_bbox_pred = self.net.get_output('rpn_bbox_pred')
rpn_bbox_targets = tf.transpose(self.net.get_output('rpn-data')[1],[0,2,3,1])
rpn_bbox_inside_weights = tf.transpose(self.net.get_output('rpn-data')[2],[0,2,3,1])
rpn_bbox_outside_weights = tf.transpose(self.net.get_output('rpn-data')[3],[0,2,3,1])

rpn_smooth_l1 = self._modified_smooth_l1(3.0, rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights)
rpn_loss_box = tf.reduce_mean(tf.reduce_sum(rpn_smooth_l1, reduction_indices=[1, 2, 3]))

猜你喜欢

转载自blog.csdn.net/wangdongwei0/article/details/81323346
RPN
今日推荐