SSD源码解读1-数据层AnnotatedDataLayer

版权声明:本文为博主原创文章,转载需注明出处。 https://blog.csdn.net/qianqing13579/article/details/80146281

年后到现在,利用自己的业余时间断断续续将caffe的SSD源码看完了,虽然中间由于工作原因暂停了一段时间,但最终还算顺利完成了,SSD源码的阅读也是今年的年度计划中比较重要的一项内容,完成了还是很有成就感的。阅读完代码后,一个最大的体会就是之前论文中很多困惑我的细节现在豁然开朗了,哈哈。

在阅读代码期间,每次遇到困惑我的地方,我会反复思考,琢磨,利用走路,吃饭的时间思考,也常常会在宿舍里来回踱步,现在我对阅读代码有了一个新的体会。当你阅读一段对你来说很难的代码的时候,不要害怕,你只要静下心来,将一段很难的代码拆分成N个子块,然后针对每个子块各个击破,等你将所有子块都击破了,然后再将所有子块串联起来连接成一个整体,再从整体思考这段代码,会有更加深刻的理解。当然这个过程起初会很难,因为起初很多东西你都不懂,就像我阅读SSD代码的时候,起初很难,要了解很多细节,但是只要有耐心有毅力,慢慢你会发现,你对这些内容越来越熟悉,你也会感到越来越轻松,直到最后你豁然开朗,发现这段很难的代码也不过如此,那种感觉实在是太美妙了。

五一的第一天稍微整理了一下SSD源码的阅读笔记,写成博客,与大家一起分享交流,由于SSD源码比较复杂,加上时间精力有限,不可能对每个细节都有深入的理解,博客中有不足之处,希望大家能够提出宝贵的意见。

这篇博客是SSD源码解读系列的第1篇,对数据层进行解读。

阅读SSD源码的时候,我为SSD源码创建了Qt工程,这样方便阅读。该Qt工程我已上传到CSDN,用Qt可以直接打开,大家可以下载该工程来阅读源码,提高阅读效率。
点击下载


数据层AnnotatedDataLayer源码解读

#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif  // USE_OPENCV
#include <stdint.h>

#include <algorithm>
#include <map>
#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/annotated_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/sampler.hpp"

namespace caffe {

// Constructs the layer by forwarding `param` both to the
// BasePrefetchingDataLayer base class and to the reader_ member.
template <typename Dtype>
AnnotatedDataLayer<Dtype>::AnnotatedDataLayer(const LayerParameter& param)
    : BasePrefetchingDataLayer<Dtype>(param), reader_(param) {}

// Destroys the layer; the only explicit step is a call to
// StopInternalThread() before the members are torn down.
template <typename Dtype>
AnnotatedDataLayer<Dtype>::~AnnotatedDataLayer()
{
  this->StopInternalThread();
}

/**
 * @brief One-time setup of the layer's outputs and prefetch buffers.
 *
 * Copies the batch samplers and the label map file name out of the layer
 * parameters, peeks the first AnnotatedDatum from the reader to infer the
 * image blob shape, and reshapes top[0], transformed_data_ and every prefetch
 * data buffer accordingly. When labels are produced, top[1] and the prefetch
 * label buffers are shaped as well.
 *
 * @param bottom Not used by this function.
 * @param top    top[0] receives the image data shape; top[1] (only touched
 *               when output_labels_ is set) receives the label shape.
 */
template <typename Dtype>
void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const AnnotatedDataParameter& anno_data_param =
      this->layer_param_.annotated_data_param();
  const int batch_size = this->layer_param_.data_param().batch_size();

  // Keep a copy of every configured augmentation sampler.
  for (int s = 0; s < anno_data_param.batch_sampler_size(); ++s) {
    batch_samplers_.push_back(anno_data_param.batch_sampler(s));
  }
  label_map_file_ = anno_data_param.label_map_file();

  // FIT_SMALL_SIZE resizing is only accepted together with a batch size of 1,
  // so that dimensions stay consistent within a batch.
  const TransformationParameter& transform_param =
      this->layer_param_.transform_param();
  const bool fit_small_size =
      transform_param.has_resize_param() &&
      transform_param.resize_param().resize_mode() ==
          ResizeParameter_Resize_mode_FIT_SMALL_SIZE;
  if (fit_small_size) {
    CHECK_EQ(batch_size, 1)
        << "Only support batch size of 1 for FIT_SMALL_SIZE.";
  }

  // Peek (do not consume) one datum and let the transformer tell us the blob
  // shape it will produce for it.
  AnnotatedDatum& first_datum = *(reader_.full().peek());
  vector<int> data_shape =
      this->data_transformer_->InferBlobShape(first_datum.datum());
  this->transformed_data_.Reshape(data_shape);

  // The output blob and the prefetch buffers hold batch_size such images.
  data_shape[0] = batch_size;
  top[0]->Reshape(data_shape);
  for (int p = 0; p < this->PREFETCH_COUNT; ++p) {
    this->prefetch_[p].data_.Reshape(data_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
            << top[0]->channels() << "," << top[0]->height() << ","
            << top[0]->width();

  // Nothing more to do when the layer emits no labels.
  if (!this->output_labels_) {
    return;
  }

  // An annotation type may come from the datum itself or from the layer
  // parameters.
  has_anno_type_ = first_datum.has_type() || anno_data_param.has_anno_type();
  vector<int> label_shape(4, 1);
  if (!has_anno_type_) {
    // No annotation type: the label blob is shaped [batch_size, 1, 1, 1].
    label_shape[0] = batch_size;
  } else {
    if (anno_data_param.has_anno_type()) {
      // A type given in AnnotatedDataParameter overrides the one stored in
      // each individual AnnotatedDatum.
      LOG(WARNING) << "type stored in AnnotatedDatum is shadowed.";
      anno_type_ = anno_data_param.anno_type();
    } else {
      anno_type_ = first_datum.type();
    }

    if (anno_type_ != AnnotatedDatum_AnnotationType_BBOX) {
      LOG(FATAL) << "Unknown annotation type.";
    } else {
      // The number of boxes differs per image, so all boxes are laid out in a
      // single spatial plane (num and channels are 1), one box per row:
      // [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]
      // See caffe.proto for group_label and instance_id.
      int box_total = 0;
      for (int g = 0; g < first_datum.annotation_group_size(); ++g) {
        box_total += first_datum.annotation_group(g).annotation_size();
      }
      label_shape[0] = 1;
      label_shape[1] = 1;
      // BasePrefetchingDataLayer<Dtype>::LayerSetUp() needs to call cpu_data
      // and gpu_data for a consistent prefetch thread, so keep at least one
      // row even when the peeked datum has no boxes.
      label_shape[2] = std::max(box_total, 1);
      label_shape[3] = 8;
    }
  }
  top[1]->Reshape(label_shape);

  // Give every prefetch label buffer the same shape.
  for (int p = 0; p < this->PREFETCH_COUNT; ++p) {
    this->prefetch_[p].label_.Reshape(label_shape);
  }
}

// Fills `batch` with batch_size transformed images and, when labels are
// enabled, the matching labels/annotations. Each datum is popped from
// reader_.full(), optionally distorted, expanded and crop-sampled, then
// transformed into its slot of batch->data_, and finally handed back through
// reader_.free().
// This function is called on prefetch thread
template<typename Dtype>
void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) 
{
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // Reshape according to the first anno_datum of each batch
  // on single input batches allows for inputs of varying dimension.
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param =this->layer_param_.annotated_data_param();
  const TransformationParameter& transform_param =this->layer_param_.transform_param();

  // Initialize the sizes of transformed_data_ and batch->data_ from a peeked
  // (not consumed) datum.
  AnnotatedDatum& anno_datum = *(reader_.full().peek());
  vector<int> top_shape =this->data_transformer_->InferBlobShape(anno_datum.datum());// author's example: 3x300x300
  this->transformed_data_.Reshape(top_shape); // transformed_data_ holds a single image; author's example for SSD300: [1,3,300,300]
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape); // batch->data_ holds batch_size images; author's example for SSD300: [batchsize,3,300,300]

  Dtype* top_data = batch->data_.mutable_cpu_data();
  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
  // Plain per-image labels are written directly; annotated labels are
  // assembled after the loop because their total count is not known yet.
  if (this->output_labels_ && !has_anno_type_) 
  {
    top_label = batch->label_.mutable_cpu_data();
  }

  // Store transformed annotation.
  map<int, vector<AnnotationGroup> > all_anno; // item_id within the batch -> transformed annotations of that image
  int num_bboxes = 0;

  for (int item_id = 0; item_id < batch_size; ++item_id) 
  {
    timer.Start();

    // Pop one datum and apply the configured preprocessing (e.g. distortion).
    // Note: this reference shadows the peeked anno_datum declared above.
    AnnotatedDatum& anno_datum = *(reader_.full().pop("Waiting for data"));
    read_time += timer.MicroSeconds();
    timer.Start();
    AnnotatedDatum distort_datum;
    // expand_datum is heap-allocated only when has_expand_param() is true;
    // otherwise it aliases distort_datum or anno_datum and must not be deleted.
    AnnotatedDatum* expand_datum = NULL;
    if (transform_param.has_distort_param()) 
    {
      distort_datum.CopyFrom(anno_datum);
      this->data_transformer_->DistortImage(anno_datum.datum(),
                                            distort_datum.mutable_datum());
      if (transform_param.has_expand_param()) 
      {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(distort_datum, expand_datum);
      } 
      else 
      {
        expand_datum = &distort_datum;
      }
    } 
    else 
    {
      if (transform_param.has_expand_param()) 
      {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(anno_datum, expand_datum);
      } 
      else 
      {
        expand_datum = &anno_datum;
      }
    }

    // sampled_datum is heap-allocated only when has_sampled is true;
    // otherwise it aliases expand_datum.
    AnnotatedDatum* sampled_datum = NULL;
    bool has_sampled = false;


    if (batch_samplers_.size() > 0)
    {
      /* 1. Data augmentation comes first (the original author maps this to the
       *    "Data augmentation" paragraph of Section 2.2 "Training" of the SSD paper).
       * For every image in the batch, each batch_sampler yields up to max_sample candidate bounding boxes.
       * Author's note: the samplers require an IoU with the objects of 0.1, 0.3, 0.5, 0.7 and 0.9, matching the paper
       * (those values live in the prototxt, not in this file).
       * Example:
          batch_sampler
          {
            sampler
            {
              min_scale: 0.3
              max_scale: 1.0
              min_aspect_ratio: 0.5
              max_aspect_ratio: 2.0
            }
            sample_constraint
            {
              min_jaccard_overlap: 0.7
            }
            max_sample: 1
            max_trials: 50
          }
       *  Author's note: for this sampler, an accepted random bounding box has IoU > 0.7 with an object in the image
       *  (TODO confirm the exact rule in SatisfySampleConstraint).
       *  Notes:
       *    1. The generated box coordinates are normalized, so they are not affected by resizing; the author remarks that
       *       detection regression commonly uses this form (e.g. MTCNN).
       *    2. Boxes are drawn at random from each batch_sampler's scale and aspect-ratio parameters; each sampler tries at most max_trials times.
       *
       */
      vector<NormalizedBBox> sampled_bboxes;// the generated coordinates are normalized
      GenerateBatchSamples(*expand_datum, batch_samplers_, &sampled_bboxes);


      /*2. Pick one bounding box at random among all generated ones,
       * crop the image region of that box (its size is that of sampled_bboxes[rand_idx] in the source image) and compute the coordinates and classes of all objects inside it.
       * Note (from the original author):
       *    1. object coordinate inside the box = (ground-truth coordinate in the source image - box coordinate) / (box side length)
       *     Both the ground truth and the box are expressed relative to the source image; the author notes MTCNN uses the same computation.
       *
       */
      if (sampled_bboxes.size() > 0)
      {
        int rand_idx = caffe_rng_rand() % sampled_bboxes.size();
        sampled_datum = new AnnotatedDatum();
        this->data_transformer_->CropImage(*expand_datum,sampled_bboxes[rand_idx],sampled_datum);

        has_sampled = true;
      } 
      else 
      {
        sampled_datum = expand_datum;
      }
    }
    else 
    {
      sampled_datum = expand_datum;
    }
    CHECK(sampled_datum != NULL);
    timer.Start();
    vector<int> shape =this->data_transformer_->InferBlobShape(sampled_datum->datum());
    if (transform_param.has_resize_param()) 
    {
        // NOTE(original author): "this part is not executed" — presumably refers to the
        // FIT_SMALL_SIZE branch below under the author's configuration; verify against your prototxt.
      if (transform_param.resize_param().resize_mode() ==ResizeParameter_Resize_mode_FIT_SMALL_SIZE) 
      {
        // Per-item reshape; top_data must be re-fetched after reshaping.
        this->transformed_data_.Reshape(shape);
        batch->data_.Reshape(shape);
        top_data = batch->data_.mutable_cpu_data();
      } 
      else 
      {
        // Every item must match the shape inferred from the peeked datum
        // (all dimensions except the batch dimension).
        CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,shape.begin() + 1));
      }
    } 
    else 
    {
      CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,
            shape.begin() + 1));
    }
    // Apply data transformations (mirror, scale, crop...)
    // transformed_data_ is pointed at this item's slot inside batch->data_,
    // so Transform() writes straight into the batch.
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(top_data + offset);
    vector<AnnotationGroup> transformed_anno_vec;
    if (this->output_labels_) 
    {
      if (has_anno_type_) 
      {
        // Make sure all data have same annotation type.
        CHECK(sampled_datum->has_type()) << "Some datum misses AnnotationType.";
        if (anno_data_param.has_anno_type()) 
        {
          sampled_datum->set_type(anno_type_);
        } 
        else 
        {
          CHECK_EQ(anno_type_, sampled_datum->type()) <<
              "Different AnnotationType.";
        }

        // Transform datum and annotation_group at the same time
        transformed_anno_vec.clear();

        // Arguments (per the original author): AnnotatedDatum, Blob<float>, vector<AnnotationGroup>

        /* 3. Convert the cropped AnnotatedDatum into a data part and an annotation part.
         *  Author's note: the data part is resized to the size configured for the data layer (e.g. 300x300) and stored for top[0];
         *  the annotations are the coordinates of all objects in the image.
         *
         * Notes (from the original author):
         *  1. The image here is not necessarily the originally cropped one: if transform_param has a crop_size, the crop is cropped again.
         *  2. Because the crop is resized here, also resizing when the LMDB is generated means the source image is resized twice,
         *     which may change object aspect ratios. SFD (Single Shot Scale-invariant Face Detector) tweaks this by making every
         *     box generated in step 1 a square, so that resizing to 300x300 does not change object aspect ratios.
         */
        this->data_transformer_->Transform(*sampled_datum,&(this->transformed_data_),&transformed_anno_vec);
        if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) 
        {
          // Count the number of bboxes.
          // Adds this item's transformed annotations to the batch-wide total.
          for (int g = 0; g < transformed_anno_vec.size(); ++g) 
          {
            num_bboxes += transformed_anno_vec[g].annotation_size();
          }
        } 
        else 
        {
          LOG(FATAL) << "Unknown annotation type.";
        }

        // Annotations of the item_id-th image in the batch.
        all_anno[item_id] = transformed_anno_vec;
      } 
      else 
      {
        this->data_transformer_->Transform(sampled_datum->datum(),&(this->transformed_data_));
        // Otherwise, store the label from datum.
        CHECK(sampled_datum->datum().has_label()) << "Cannot find any label.";
        top_label[item_id] = sampled_datum->datum().label();
      }
    } 
    else 
    {
      this->data_transformer_->Transform(sampled_datum->datum(),&(this->transformed_data_));
    }
    // clear memory
    // Only the objects created with `new` above are deleted; the two flags
    // mirror exactly the conditions under which those allocations happened.
    if (has_sampled) {
      delete sampled_datum;
    }
    if (transform_param.has_expand_param()) {
      delete expand_datum;
    }
    trans_time += timer.MicroSeconds();

    // Hand the consumed datum back to the reader's free queue.
    reader_.free().push(const_cast<AnnotatedDatum*>(&anno_datum));
  }

  // Store "rich" annotation if needed.
  /*4. Finally write the annotations into the label blob (the author refers to it as top[1]), shape [1,1,numberOfBoxes,8].
   * Row format: [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]
   * Meaning of the 8 values (per the original author): the box of instance instance_id, under class group_label,
   * in the item_id-th image of the batch, has coordinates [xmin, ymin, xmax, ymax].
   *
   */
  if (this->output_labels_ && has_anno_type_) 
  {
    vector<int> label_shape(4);
    if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) 
    {
      label_shape[0] = 1;
      label_shape[1] = 1;
      label_shape[3] = 8;
      if (num_bboxes == 0) 
      {
        // Store all -1 in the label.
        label_shape[2] = 1;
        batch->label_.Reshape(label_shape);
        caffe_set<Dtype>(8, -1, batch->label_.mutable_cpu_data());
      } 
      else 
      {

        // num_bboxes is the total number of objects over all images processed above.
        label_shape[2] = num_bboxes;
        batch->label_.Reshape(label_shape);
        top_label = batch->label_.mutable_cpu_data();
        int idx = 0;

        // Walk the label information of every image in the batch.
        for (int item_id = 0; item_id < batch_size; ++item_id) 
        {
            // Label information of the item_id-th image.
          const vector<AnnotationGroup>& anno_vec = all_anno[item_id];
          for (int g = 0; g < anno_vec.size(); ++g) 
          {
            const AnnotationGroup& anno_group = anno_vec[g];

            for (int a = 0; a < anno_group.annotation_size(); ++a) 
            {
              const Annotation& anno = anno_group.annotation(a);
              const NormalizedBBox& bbox = anno.bbox();

              // One 8-value row per annotation, in the order documented above.
              top_label[idx++] = item_id;
              top_label[idx++] = anno_group.group_label();
              top_label[idx++] = anno.instance_id();
              top_label[idx++] = bbox.xmin();
              top_label[idx++] = bbox.ymin();
              top_label[idx++] = bbox.xmax();
              top_label[idx++] = bbox.ymax();
              top_label[idx++] = bbox.difficult();
            }
          }
        }
      }
    }
    else
    {
      LOG(FATAL) << "Unknown annotation type.";
    }
  }
  timer.Stop();
  batch_timer.Stop();
  // read_time and trans_time were accumulated in microseconds; report in ms.
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(AnnotatedDataLayer);
REGISTER_LAYER_CLASS(AnnotatedData);

}  // namespace caffe

数据层中有几个比较重要的函数:GenerateBatchSamples(),this->data_transformer_->CropImage(),this->data_transformer_->Transform(),下面对它们详细解读一下。

GenerateBatchSamples

void GenerateBatchSamples(const AnnotatedDatum& anno_datum,
                          const vector<BatchSampler>& batch_samplers,
                          vector<NormalizedBBox>* sampled_bboxes) 
{
  sampled_bboxes->clear();

  // 获取groundtruth box
  vector<NormalizedBBox> object_bboxes;
  GroupObjectBBoxes(anno_datum, &object_bboxes); 

  // 对于每个采样器生成多个box
  for (int i = 0; i < batch_samplers.size(); ++i) 
  {
     // Use original image as the source for sampling.
    if (batch_samplers[i].use_original_image()) 
    {
      NormalizedBBox unit_bbox;
      unit_bbox.set_xmin(0);
      unit_bbox.set_ymin(0);
      unit_bbox.set_xmax(1);
      unit_bbox.set_ymax(1);
      GenerateSamples(unit_bbox,  // 单位box
                      object_bboxes,// ground truth box
                      batch_samplers[i], // 采样器
                      sampled_bboxes);
    }
  }
}

void GenerateSamples(const NormalizedBBox& source_bbox, // 单位box
                     const vector<NormalizedBBox>& object_bboxes, // object_bboxes就是该图像中所有的ground truth boxes
                     const BatchSampler& batch_sampler, // 采样器
                     vector<NormalizedBBox>* sampled_bboxes) 
{
  int found = 0;

  // 每个采样器batch_sampler都要尝试max_trials次
  for (int i = 0; i < batch_sampler.max_trials(); ++i) 
  {
      // 每个batch_sampler生成的boundingbox个数大于等于max_sample了,就跳出
    if (batch_sampler.has_max_sample() && found >= batch_sampler.max_sample()) 
    {
      break;
    }

    // Generate sampled_bbox in the normalized space [0, 1].
    // 随机生成一个box
    NormalizedBBox sampled_bbox;
    SampleBBox(batch_sampler.sampler(), &sampled_bbox);

    // Transform the sampled_bbox w.r.t. source_bbox.
    // 转换为在单位box中的坐标,由于都是单位box,所以转换后还是自己
    LocateBBox(source_bbox, sampled_bbox, &sampled_bbox);

    // Determine if the sampled bbox is positive or negative by the constraint.
    // 所有的ground truth 与生成的boundingbox计算IOU,是否满足条件
    if (SatisfySampleConstraint(sampled_bbox, object_bboxes,batch_sampler.sample_constraint())) 
    {
      ++found;
      sampled_bboxes->push_back(sampled_bbox);
    }
  }
}

DataTransformer::CropImage

/**
 * @brief Crops an annotated datum to `bbox` and carries its annotations over.
 *
 * @param anno_datum         Source image plus annotations.
 * @param bbox               Crop region.
 * @param cropped_anno_datum Output; receives the cropped datum, the source's
 *                           annotation type, and the annotations transformed
 *                           with respect to the clipped crop box.
 */
template<typename Dtype>
void DataTransformer<Dtype>::CropImage(const AnnotatedDatum& anno_datum,
                                       const NormalizedBBox& bbox,
                                       AnnotatedDatum* cropped_anno_datum) {
  // Image part: delegate to the Datum overload, then copy the annotation type.
  CropImage(anno_datum.datum(), bbox, cropped_anno_datum->mutable_datum());
  cropped_anno_datum->set_type(anno_datum.type());

  // Annotation part: clip the crop box, then transform the annotations
  // according to it. Neither resizing nor mirroring is requested here.
  NormalizedBBox clipped_crop;
  ClipBBox(bbox, &clipped_crop);

  const bool kNoResize = false;
  const bool kNoMirror = false;
  TransformAnnotation(anno_datum, kNoResize, clipped_crop, kNoMirror,
                      cropped_anno_datum->mutable_annotation_group());
}

/**
 * @brief Projects the bbox annotations of `anno_datum` onto `crop_bbox`.
 *
 * For every annotation group, each box is optionally adjusted by the resize
 * policy, filtered by the emit constraint (when one is configured), and
 * projected with ProjectBBox(). Boxes for which ProjectBBox() returns true are
 * kept, optionally mirrored horizontally and extrapolated. Groups that keep at
 * least one box are appended to the output with their original group_label.
 *
 * @param anno_datum  Source annotations; must be of BBOX type, otherwise the
 *                    function logs a fatal error.
 * @param do_resize   Apply the resize-related steps (only effective when
 *                    param_ has a resize_param).
 * @param crop_bbox   Box the annotations are projected onto.
 * @param do_mirror   Flip kept boxes horizontally (x -> 1 - x).
 * @param transformed_anno_group_all Output; kept groups are appended.
 */
template<typename Dtype>
void DataTransformer<Dtype>::TransformAnnotation(
    const AnnotatedDatum& anno_datum, const bool do_resize,
    const NormalizedBBox& crop_bbox, const bool do_mirror,
    RepeatedPtrField<AnnotationGroup>* transformed_anno_group_all) {
  if (anno_datum.type() != AnnotatedDatum_AnnotationType_BBOX) {
    LOG(FATAL) << "Unknown annotation type.";
    return;
  }

  const int img_height = anno_datum.datum().height();
  const int img_width = anno_datum.datum().width();
  // Both predicates are loop-invariant, so evaluate them once.
  const bool apply_resize = do_resize && param_.has_resize_param();
  const bool check_emit = param_.has_emit_constraint();

  // Go through each AnnotationGroup.
  for (int gi = 0; gi < anno_datum.annotation_group_size(); ++gi) {
    const AnnotationGroup& src_group = anno_datum.annotation_group(gi);
    AnnotationGroup kept_group;

    // Go through each Annotation of the group.
    for (int ai = 0; ai < src_group.annotation_size(); ++ai) {
      const Annotation& src_anno = src_group.annotation(ai);

      // Adjust bounding box annotation.
      NormalizedBBox resize_bbox = src_anno.bbox();
      if (apply_resize) {
        CHECK_GT(img_height, 0);
        CHECK_GT(img_width, 0);
        UpdateBBoxByResizePolicy(param_.resize_param(), img_width, img_height,
                                 &resize_bbox);
      }

      // Drop boxes that do not meet the emit constraint, when one is set.
      if (check_emit &&
          !MeetEmitConstraint(crop_bbox, resize_bbox,
                              param_.emit_constraint())) {
        continue;
      }

      // Project the box onto crop_bbox; boxes for which the projection
      // reports false are skipped.
      NormalizedBBox proj_bbox;
      if (!ProjectBBox(crop_bbox, resize_bbox, &proj_bbox)) {
        continue;
      }

      Annotation* out_anno = kept_group.add_annotation();
      out_anno->set_instance_id(src_anno.instance_id());
      NormalizedBBox* out_bbox = out_anno->mutable_bbox();
      out_bbox->CopyFrom(proj_bbox);

      if (do_mirror) {
        // Horizontal flip: the new xmin comes from the old xmax and vice versa.
        Dtype old_xmin = out_bbox->xmin();
        out_bbox->set_xmin(1 - out_bbox->xmax());
        out_bbox->set_xmax(1 - old_xmin);
      }
      if (apply_resize) {
        ExtrapolateBBox(param_.resize_param(), img_height, img_width,
                        crop_bbox, out_bbox);
      }
    }

    // Save for output: only groups that kept at least one annotation.
    if (kept_group.annotation_size() > 0) {
      kept_group.set_group_label(src_group.group_label());
      transformed_anno_group_all->Add()->CopyFrom(kept_group);
    }
  }
}

DataTransformer::Transform

// Convenience overload for callers that do not need to know whether the
// image was mirrored: forwards to the four-argument overload and discards the
// mirror flag it reports.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    vector<AnnotationGroup>* transformed_anno_vec) {
  bool mirrored_ignored;  // written by the callee, never read here
  Transform(anno_datum, transformed_blob, transformed_anno_vec,
            &mirrored_ignored);
}

// Vector-returning overload: runs the RepeatedPtrField overload on a local
// container and appends every resulting group, in order, to
// transformed_anno_vec. do_mirror is passed through to that overload.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    vector<AnnotationGroup>* transformed_anno_vec, bool* do_mirror)
{
  RepeatedPtrField<AnnotationGroup> groups;
  Transform(anno_datum, transformed_blob, &groups, do_mirror);

  const int group_count = groups.size();
  for (int idx = 0; idx < group_count; ++idx)
  {
    const AnnotationGroup& group = groups.Get(idx);
    transformed_anno_vec->push_back(group);
  }
}

/**
 * @brief Transforms an annotated datum: image first, annotations second.
 *
 * @param anno_datum       Source image plus annotations.
 * @param transformed_blob Output blob for the transformed image.
 * @param transformed_anno_group_all Output; transformed annotation groups are
 *                         appended by TransformAnnotation().
 * @param do_mirror        Output of the image transform; its value is then
 *                         passed on to the annotation transform.
 */
template<typename Dtype>
void DataTransformer<Dtype>::Transform(
    const AnnotatedDatum& anno_datum, Blob<Dtype>* transformed_blob,
    RepeatedPtrField<AnnotationGroup>* transformed_anno_group_all,
    bool* do_mirror) {
  // Transform datum. The image transform reports the crop box it used and
  // whether it mirrored the image.
  // Author's note: without a crop_size in the transformer parameters the crop
  // box stays the whole image, i.e. (0,0,1,1) — TODO confirm in the Datum
  // overload, which is not shown here.
  NormalizedBBox crop_bbox;
  Transform(anno_datum.datum(), transformed_blob, &crop_bbox, do_mirror);

  // Transform annotation with the same crop box and mirror decision, with
  // resize handling enabled.
  const bool kDoResize = true;
  TransformAnnotation(anno_datum, kDoResize, crop_bbox, *do_mirror,
                      transformed_anno_group_all);
}

数据层的源码大概就是这样,大家有什么疑问的,可以留言一起讨论。

2018-4-29 22:44:01
Last updated: 2018-5-1 10:53:20


非常感谢您的阅读,如果您觉得这篇文章对您有帮助,欢迎扫码进行赞赏。
这里写图片描述

猜你喜欢

转载自blog.csdn.net/qianqing13579/article/details/80146281
今日推荐