Caffe框架源码剖析—数据层DataLayer

原文链接:http://blog.csdn.net/tianrolin/article/details/52522662

Caffe网络正向传导时,首先进行的是DataLayer数据层的传导。该层从文件读取数据,加载至它的上一层卷积层。反向传播时,因为数据层不需要反传,所以它的Backward_cpu()和Backward_gpu()都是空函数。下面看一下DataLayer类图关系。

首先从父类BaseDataLayer开始看源码,base_data_layer.hpp头文件:

  1. template <typename Dtype>  
  2. class BaseDataLayer : public Layer<Dtype> {  
  3.  public:  
  4.   // 构造函数  
  5.   explicit BaseDataLayer(const LayerParameter& param);  
  6.   // 实现一般数据层构建,并调用DataLayerSetup函数  
  7.   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  8.       const vector<Blob<Dtype>*>& top);  
  9.   // 数据层可在并行时共享  
  10.   virtual inline bool ShareInParallel() const { return true; }  
  11.   // 空的构建函数(该函数为虚函数,待子类重载)  
  12.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  13.       const vector<Blob<Dtype>*>& top) {}  
  14.   // 数据层没有bottom层,因此Reshape函数为空函数  
  15.   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  16.       const vector<Blob<Dtype>*>& top) {}  
  17.   
  18.   // 反向传播,空函数  
  19.   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  20.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  21.   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  22.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  23.   
  24.  protected:  
  25.   TransformationParameter transform_param_;  
  26.   shared_ptr<DataTransformer<Dtype> > data_transformer_;  
  27.   // 是否包含有输出标签  
  28.   bool output_labels_;  
  29. };  
template <typename Dtype>
class BaseDataLayer : public Layer<Dtype> {
 public:
  // 构造函数
  explicit BaseDataLayer(const LayerParameter& param);
  // 实现一般数据层构建,并调用DataLayerSetup函数
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  // 数据层可在并行时共享
  virtual inline bool ShareInParallel() const { return true; }
  // 空的构建函数(该函数为虚函数,待子类重载)
  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}
  // 数据层没有bottom层,因此Reshape函数为空函数
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {}

  // 反向传播,空函数
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}

 protected:
  TransformationParameter transform_param_;
  shared_ptr<DataTransformer<Dtype> > data_transformer_;
  // 是否包含有输出标签
  bool output_labels_;
};

base_data_layer.cpp实现文件

  1. // 构造函数  
  2. template <typename Dtype>  
  3. BaseDataLayer<Dtype>::BaseDataLayer(const LayerParameter& param)  
  4.     : Layer<Dtype>(param),  
  5.       transform_param_(param.transform_param()) {  
  6. }  
  7.   
  8. template <typename Dtype>  
  9. void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  10.       const vector<Blob<Dtype>*>& top) {  
  11.   // 如果top层size大于1,则包含有标签  
  12.   if (top.size() == 1) {  
  13.     output_labels_ = false;  
  14.   } else {  
  15.     output_labels_ = true;  
  16.   }  
  17.   data_transformer_.reset(  
  18.       new DataTransformer<Dtype>(transform_param_, this->phase_));  
  19.   // 初始化随机数生成器  
  20.   data_transformer_->InitRand();  
  21.   // 调用构建虚函数  
  22.   DataLayerSetUp(bottom, top);  
  23. }  
// 构造函数
template <typename Dtype>
BaseDataLayer<Dtype>::BaseDataLayer(const LayerParameter& param)
    : Layer<Dtype>(param),
      transform_param_(param.transform_param()) {
}

template <typename Dtype>
void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  // 如果top层size大于1,则包含有标签
  if (top.size() == 1) {
    output_labels_ = false;
  } else {
    output_labels_ = true;
  }
  data_transformer_.reset(
      new DataTransformer<Dtype>(transform_param_, this->phase_));
  // 初始化随机数生成器
  data_transformer_->InitRand();
  // 调用构建虚函数
  DataLayerSetUp(bottom, top);
}

接下来看一下子类BasePrefetchingDataLayer类,该类不仅继承了BaseDataLayer类,还继承自InternalThread类。因此该类重载了InternalThread类的虚函数InternalThreadEntry()。

  1. template <typename Dtype>  
  2. class BasePrefetchingDataLayer :  
  3.     public BaseDataLayer<Dtype>, public InternalThread {  
  4.  public:  
  5.   explicit BasePrefetchingDataLayer(const LayerParameter& param);  
  6.   // 构建函数  
  7.   void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  8.       const vector<Blob<Dtype>*>& top);  
  9.   
  10.   // CPU正向传导函数  
  11.   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  12.       const vector<Blob<Dtype>*>& top);  
  13.   // GPU正向传导函数  
  14.   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  15.       const vector<Blob<Dtype>*>& top);  
  16.   
  17.   // 预取数据块大小  
  18.   static const int PREFETCH_COUNT = 3;  
  19.   
  20.  protected:  
  21.   // 线程函数,虚函数重载  
  22.   virtual void InternalThreadEntry();  
  23.   // 加载batch,纯虚函数,由子类DataLayer实现  
  24.   virtual void load_batch(Batch<Dtype>* batch) = 0;  
  25.   
  26.   Batch<Dtype> prefetch_[PREFETCH_COUNT];  
  27.   BlockingQueue<Batch<Dtype>*> prefetch_free_;  
  28.   BlockingQueue<Batch<Dtype>*> prefetch_full_;  
  29.   
  30.   Blob<Dtype> transformed_data_;  
  31. };  
template <typename Dtype>
class BasePrefetchingDataLayer :
    public BaseDataLayer<Dtype>, public InternalThread {
 public:
  explicit BasePrefetchingDataLayer(const LayerParameter& param);
  // 构建函数
  void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  // CPU正向传导函数
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  // GPU正向传导函数
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  // 预取数据块大小
  static const int PREFETCH_COUNT = 3;

 protected:
  // 线程函数,虚函数重载
  virtual void InternalThreadEntry();
  // 加载batch,纯虚函数,由子类DataLayer实现
  virtual void load_batch(Batch<Dtype>* batch) = 0;

  Batch<Dtype> prefetch_[PREFETCH_COUNT];
  BlockingQueue<Batch<Dtype>*> prefetch_free_;
  BlockingQueue<Batch<Dtype>*> prefetch_full_;

  Blob<Dtype> transformed_data_;
};

base_data_layer.cpp实现文件


  1. template <typename Dtype>  
  2. BasePrefetchingDataLayer<Dtype>::BasePrefetchingDataLayer(  
  3.     const LayerParameter& param)  
  4.     : BaseDataLayer<Dtype>(param),  
  5.       prefetch_free_(), prefetch_full_() {  
  6.   for (int i = 0; i < PREFETCH_COUNT; ++i) {  
  7.     prefetch_free_.push(&prefetch_[i]);  
  8.   }  
  9. }  
  10.   
  11. template <typename Dtype>  
  12. void BasePrefetchingDataLayer<Dtype>::LayerSetUp(  
  13.     const vector<Blob<Dtype>>& bottom, const vector<Blob<Dtype>>& top) {  
  14.   // 先调用父类LayerSetUp  
  15.   BaseDataLayer<Dtype>::LayerSetUp(bottom, top);  
  16.   // 线程开启前先分配内存&显存,防止在某些GPU上报错  
  17.   for (int i = 0; i < PREFETCH_COUNT; ++i) {  
  18.     prefetch_[i].data_.mutable_cpu_data();  
  19.     if (this->output_labels_) {  
  20.       prefetch_[i].label_.mutable_cpu_data();  
  21.     }  
  22.   }  
  23. #ifndef CPU_ONLY  
  24.   if (Caffe::mode() == Caffe::GPU) {  
  25.     for (int i = 0; i < PREFETCH_COUNT; ++i) {  
  26.       prefetch_[i].data_.mutable_gpu_data();  
  27.       if (this->output_labels_) {  
  28.         prefetch_[i].label_.mutable_gpu_data();  
  29.       }  
  30.     }  
  31.   }  
  32. #endif  
  33.   DLOG(INFO) << ”Initializing prefetch”;  
  34.   // 初始化随机数生成器  
  35.   this->data_transformer_->InitRand();  
  36.   // 开启线程  
  37.   StartInternalThread();  
  38.   DLOG(INFO) << ”Prefetch initialized.”;  
  39. }  
  40.   
  41. // 线程函数,由StartInternalThread开启  
  42. template <typename Dtype>  
  43. void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {  
  44. #ifndef CPU_ONLY  
  45.   // 在GPU上启用stream异步加载  
  46.   cudaStream_t stream;  
  47.   if (Caffe::mode() == Caffe::GPU) {  
  48.     CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));  
  49.   }  
  50. #endif  
  51.   
  52.   try {  
  53.     while (!must_stop()) {  
  54.       Batch<Dtype> batch = prefetch_free_.pop();  
  55.       // 加载batch,该函数由子类DataLayer实现  
  56.       load_batch(batch);  
  57. #ifndef CPU_ONLY  
  58.       if (Caffe::mode() == Caffe::GPU) {  
  59.         batch->data_.data().get()->async_gpu_push(stream);  
  60.         CUDA_CHECK(cudaStreamSynchronize(stream));  
  61.       }  
  62. #endif  
  63.       prefetch_full_.push(batch);  
  64.     }  
  65.   } catch (boost::thread_interrupted&) {  
  66.     // Interrupted exception is expected on shutdown  
  67.   }  
  68. #ifndef CPU_ONLY  
  69.   if (Caffe::mode() == Caffe::GPU) {  
  70.     CUDA_CHECK(cudaStreamDestroy(stream));  
  71.   }  
  72. #endif  
  73. }  
  74.   
  75. // CPU正向传导  
  76. template <typename Dtype>  
  77. void BasePrefetchingDataLayer<Dtype>::Forward_cpu(  
  78.     const vector<Blob<Dtype>>& bottom, const vector<Blob<Dtype>>& top) {  
  79.   Batch<Dtype> batch = prefetch_full_.pop(”Data layer prefetch queue empty”);  
  80.   // Reshape成与batch数据同一维度  
  81.   top[0]->ReshapeLike(batch->data_);  
  82.   // 将batch数据拷贝至top层blob[0]  
  83.   caffe_copy(batch->data_.count(), batch->data_.cpu_data(),  
  84.              top[0]->mutable_cpu_data());  
  85.   DLOG(INFO) << ”Prefetch copied”;  
  86.   // 如果包含输出标签  
  87.   if (this->output_labels_) {  
  88.     // Reshape成batch标签同一维度  
  89.     top[1]->ReshapeLike(batch->label_);  
  90.     // 将batch标签拷贝至top层blob[1]  
  91.     caffe_copy(batch->label_.count(), batch->label_.cpu_data(),  
  92.         top[1]->mutable_cpu_data());  
  93.   }  
  94.   
  95.   prefetch_free_.push(batch);  
  96. }  
  97. // 如果CPU_ONLY模式则禁止Forward_gpu和Backward_gpu函数  
  98. #ifdef CPU_ONLY  
  99. STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward);  
  100. #endif  
template <typename Dtype> 
BasePrefetchingDataLayer<Dtype>::BasePrefetchingDataLayer(
const LayerParameter& param)
: BaseDataLayer<Dtype>(param),
prefetch_free_(), prefetch_full_() {
for (int i = 0; i < PREFETCH_COUNT; ++i) {
prefetch_free_.push(&prefetch_[i]);
}
}

template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>>& bottom, const vector<Blob<Dtype>>& top) {
// 先调用父类LayerSetUp
BaseDataLayer<Dtype>::LayerSetUp(bottom, top);
// 线程开启前先分配内存&显存,防止在某些GPU上报错
for (int i = 0; i < PREFETCH_COUNT; ++i) {
prefetch_[i].data_.mutable_cpu_data();
if (this->output_labels_) {
prefetch_[i].label_.mutable_cpu_data();
}
}

ifndef CPU_ONLY

if (Caffe::mode() == Caffe::GPU) {
for (int i = 0; i < PREFETCH_COUNT; ++i) {
prefetch_[i].data_.mutable_gpu_data();
if (this->output_labels_) {
prefetch_[i].label_.mutable_gpu_data();
}
}
}

endif

DLOG(INFO) << "Initializing prefetch";
// 初始化随机数生成器
this->data_transformer_->InitRand();
// 开启线程
StartInternalThread();
DLOG(INFO) << "Prefetch initialized.";
}

// 线程函数,由StartInternalThread开启
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {

ifndef CPU_ONLY

// 在GPU上启用stream异步加载
cudaStream_t stream;
if (Caffe::mode() == Caffe::GPU) {
CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
}

endif

try {
while (!must_stop()) {
Batch<Dtype>* batch = prefetch_free_.pop();
// 加载batch,该函数由子类DataLayer实现
load_batch(batch);

ifndef CPU_ONLY

  if (Caffe::mode() == Caffe::GPU) {
    batch-&gt;data_.data().get()-&gt;async_gpu_push(stream);
    CUDA_CHECK(cudaStreamSynchronize(stream));
  }

endif

  prefetch_full_.push(batch);
}

} catch (boost::thread_interrupted&) {
// Interrupted exception is expected on shutdown
}

ifndef CPU_ONLY

if (Caffe::mode() == Caffe::GPU) {
CUDA_CHECK(cudaStreamDestroy(stream));
}

endif

}

// CPU正向传导
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>>& bottom, const vector<Blob<Dtype>>& top) {
Batch<Dtype>* batch = prefetch_full_.pop(“Data layer prefetch queue empty”);
// Reshape成与batch数据同一维度
top[0]->ReshapeLike(batch->data_);
// 将batch数据拷贝至top层blob[0]
caffe_copy(batch->data_.count(), batch->data_.cpu_data(),
top[0]->mutable_cpu_data());
DLOG(INFO) << “Prefetch copied”;
// 如果包含输出标签
if (this->output_labels_) {
// Reshape成batch标签同一维度
top[1]->ReshapeLike(batch->label_);
// 将batch标签拷贝至top层blob[1]
caffe_copy(batch->label_.count(), batch->label_.cpu_data(),
top[1]->mutable_cpu_data());
}

prefetch_free_.push(batch);
}
// 如果CPU_ONLY模式则禁止Forward_gpu和Backward_gpu函数

ifdef CPU_ONLY

STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward);

endif


最后分析下最终的子类DataLayer,由于很多方法由它的父类实现了,该类功能很简单了,只重载了两个虚函数DataLayerSetUp()和load_batch()。

  1. template <typename Dtype>  
  2. class DataLayer : public BasePrefetchingDataLayer<Dtype> {  
  3.  public:  
  4.   explicit DataLayer(const LayerParameter& param);  
  5.   virtual ~DataLayer();  
  6.   
  7.   // 构建函数,重载虚函数  
  8.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  9.       const vector<Blob<Dtype>*>& top);  
  10.   // DataLayer uses DataReader instead for sharing for parallelism  
  11.   virtual inline bool ShareInParallel() const { return false; }  
  12.   virtual inline const char* type() const { return “Data”; }  
  13.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  14.   virtual inline int MinTopBlobs() const { return 1; }  
  15.   virtual inline int MaxTopBlobs() const { return 2; }  
  16.   
  17.  protected:  
  18.   // 加载batch,重载虚函数  
  19.   virtual void load_batch(Batch<Dtype>* batch);  
  20.   
  21.   // DataReader对象  
  22.   DataReader reader_;  
  23. };  
template <typename Dtype>
class DataLayer : public BasePrefetchingDataLayer<Dtype> {
 public:
  explicit DataLayer(const LayerParameter& param);
  virtual ~DataLayer();

  // 构建函数,重载虚函数
  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  // DataLayer uses DataReader instead for sharing for parallelism
  virtual inline bool ShareInParallel() const { return false; }
  virtual inline const char* type() const { return "Data"; }
  virtual inline int ExactNumBottomBlobs() const { return 0; }
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 2; }

 protected:
  // 加载batch,重载虚函数
  virtual void load_batch(Batch<Dtype>* batch);

  // DataReader对象
  DataReader reader_;
};

cpp文件如下,

  1. // 构造函数  
  2. template <typename Dtype>  
  3. DataLayer<Dtype>::DataLayer(const LayerParameter& param)  
  4.   : BasePrefetchingDataLayer<Dtype>(param),  
  5.     reader_(param) {  
  6. }  
  7. // 析构函数  
  8. template <typename Dtype>  
  9. DataLayer<Dtype>::~DataLayer() {  
  10.   // 终止线程  
  11.   this->StopInternalThread();  
  12. }  
  13.   
  14. template <typename Dtype>  
  15. void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  16.       const vector<Blob<Dtype>*>& top) {  
  17.   const int batch_size = this->layer_param_.data_param().batch_size();  
  18.   // 读取一个dataum,用来初始化top blob维度  
  19.   Datum& datum = *(reader_.full().peek());  
  20.   
  21.   // 从datum获取单个数据维度  
  22.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  23.   this->transformed_data_.Reshape(top_shape);  
  24.   // 加上batch尺寸  
  25.   top_shape[0] = batch_size;  
  26.   // Reshape  
  27.   top[0]->Reshape(top_shape);  
  28.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  29.     // Reshape,并分配data内存  
  30.     this->prefetch_[i].data_.Reshape(top_shape);  
  31.   }  
  32.   // 输出尺寸信息  
  33.   LOG(INFO) << ”output data size: ” << top[0]->num() << “,”  
  34.       << top[0]->channels() << ”,” << top[0]->height() << “,”  
  35.       << top[0]->width();  
  36.   // label  
  37.   if (this->output_labels_) {  
  38.     vector<int> label_shape(1, batch_size);  
  39.     top[1]->Reshape(label_shape);  
  40.     for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  41.       // Reshape,并分配label内存  
  42.       this->prefetch_[i].label_.Reshape(label_shape);  
  43.     }  
  44.   }  
  45. }  
  46.   
  47. // 该函数被InternalThreadEntry线程函数调用  
  48. template<typename Dtype>  
  49. void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  50.   CPUTimer batch_timer;  
  51.   batch_timer.Start();  
  52.   double read_time = 0;  
  53.   double trans_time = 0;  
  54.   CPUTimer timer;  
  55.   CHECK(batch->data_.count());  
  56.   CHECK(this->transformed_data_.count());  
  57.   
  58.   // 读取一个dataum,用来初始化top blob维度,同上  
  59.   const int batch_size = this->layer_param_.data_param().batch_size();  
  60.   Datum& datum = *(reader_.full().peek());  
  61.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  62.   this->transformed_data_.Reshape(top_shape);  
  63.   top_shape[0] = batch_size;  
  64.   batch->data_.Reshape(top_shape);  
  65.   
  66.   Dtype* top_data = batch->data_.mutable_cpu_data();  
  67.   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables  
  68.   
  69.   if (this->output_labels_) {  
  70.     top_label = batch->label_.mutable_cpu_data();  
  71.   }  
  72.   
  73.   // 循环加载batch  
  74.   for (int item_id = 0; item_id < batch_size; ++item_id) {  
  75.     timer.Start();  
  76.     // 读取数据datum  
  77.     Datum& datum = *(reader_.full().pop(”Waiting for data”));  
  78.     // 统计读取时间  
  79.     read_time += timer.MicroSeconds();  
  80.     timer.Start();  
  81.     // 计算指针offset  
  82.     int offset = batch->data_.offset(item_id);  
  83.     this->transformed_data_.set_cpu_data(top_data + offset);  
  84.     // 将datum数据拷贝到batch中  
  85.     this->data_transformer_->Transform(datum, &(this->transformed_data_));  
  86.     // 拷贝标签  
  87.     if (this->output_labels_) {  
  88.       top_label[item_id] = datum.label();  
  89.     }  
  90.     // 统计拷贝时间  
  91.     trans_time += timer.MicroSeconds();  
  92.   
  93.     reader_.free().push(const_cast<Datum*>(&datum));  
  94.   }  
  95.   timer.Stop();  
  96.   // 统计加载batch总时间  
  97.   batch_timer.Stop();  
  98.   // 输出时间开销  
  99.   DLOG(INFO) << ”Prefetch batch: ” << batch_timer.MilliSeconds() << “ ms.”;  
  100.   DLOG(INFO) << ”     Read time: ” << read_time / 1000 << “ ms.”;  
  101.   DLOG(INFO) << ”Transform time: ” << trans_time / 1000 << “ ms.”;  
  102. }  
// 构造函数
template <typename Dtype>
DataLayer<Dtype>::DataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayer<Dtype>(param),
    reader_(param) {
}
// 析构函数
template <typename Dtype>
DataLayer<Dtype>::~DataLayer() {
  // 终止线程
  this->StopInternalThread();
}

template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int batch_size = this->layer_param_.data_param().batch_size();
  // 读取一个dataum,用来初始化top blob维度
  Datum& datum = *(reader_.full().peek());

  // 从datum获取单个数据维度
  vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
  this->transformed_data_.Reshape(top_shape);
  // 加上batch尺寸
  top_shape[0] = batch_size;
  // Reshape
  top[0]->Reshape(top_shape);
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    // Reshape,并分配data内存
    this->prefetch_[i].data_.Reshape(top_shape);
  }
  // 输出尺寸信息
  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label
  if (this->output_labels_) {
    vector<int> label_shape(1, batch_size);
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
      // Reshape,并分配label内存
      this->prefetch_[i].label_.Reshape(label_shape);
    }
  }
}

// 该函数被InternalThreadEntry线程函数调用
template<typename Dtype>
void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // 读取一个dataum,用来初始化top blob维度,同上
  const int batch_size = this->layer_param_.data_param().batch_size();
  Datum& datum = *(reader_.full().peek());
  vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
  this->transformed_data_.Reshape(top_shape);
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);

  Dtype* top_data = batch->data_.mutable_cpu_data();
  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables

  if (this->output_labels_) {
    top_label = batch->label_.mutable_cpu_data();
  }

  // 循环加载batch
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    timer.Start();
    // 读取数据datum
    Datum& datum = *(reader_.full().pop("Waiting for data"));
    // 统计读取时间
    read_time += timer.MicroSeconds();
    timer.Start();
    // 计算指针offset
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(top_data + offset);
    // 将datum数据拷贝到batch中
    this->data_transformer_->Transform(datum, &(this->transformed_data_));
    // 拷贝标签
    if (this->output_labels_) {
      top_label[item_id] = datum.label();
    }
    // 统计拷贝时间
    trans_time += timer.MicroSeconds();

    reader_.free().push(const_cast<Datum*>(&datum));
  }
  timer.Stop();
  // 统计加载batch总时间
  batch_timer.Stop();
  // 输出时间开销
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}


猜你喜欢

转载自blog.csdn.net/u011956147/article/details/77987504