Caffe之Scale层源码

简介

最近要使用 network slimming 对模型进行压缩,需要在 Scale 层的反向传播中对 scale 参数的梯度(scale_diff)加入 L1 正则项,因此阅读了 Scale 层的源码。
公式:$y = ax + b$。
公式比较简单:对 feature map 的每个元素乘以 a,再加上 b。同一个 feature map(通道)内的所有元素共用同一个 a,因此 a 的维度等于通道数 c,这是理解源码的前提。
反向传播:
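记 $n$ 为 outer 维下标,$c$ 为 scale 维(通道)下标,$i$ 为 inner 维下标,则

$$\frac{\partial L}{\partial x_{n,c,i}} = a_c \, \frac{\partial L}{\partial y_{n,c,i}}$$

$$\frac{\partial L}{\partial a_c} = \sum_{n}\sum_{i} x_{n,c,i} \, \frac{\partial L}{\partial y_{n,c,i}}$$

$$\frac{\partial L}{\partial b_c} = \sum_{n}\sum_{i} \frac{\partial L}{\partial y_{n,c,i}}$$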

参数设置

// Configuration for the Scale layer, which computes y = a * x (+ b).
message ScaleParameter {

  optional int32 axis = 1 [default = 1]; // axis to operate on; defaults to 1
  optional int32 num_axes = 2 [default = 1];// number of axes covered; defaults to 1, i.e. a single axis
  optional FillerParameter filler = 3; // the scale `a` is learnable; this filler sets its initial value
  optional bool bias_term = 4 [default = false];// whether to add a bias term `b`; off by default
  optional FillerParameter bias_filler = 5; // filler used to initialize the bias term
}

成员变量

  // Nested layer that handles the bias term; Backward_cpu delegates the bias
  // gradient to it instead of computing it here.
  shared_ptr<Layer<Dtype> > bias_layer_; // bias learning is delegated directly to a Bias layer
  // Bottom vector handed to bias_layer_ when its Backward is invoked.
  vector<Blob<Dtype>*> bias_bottom_vec_;
  // propagate_down flags handed to bias_layer_ when its Backward is invoked.
  vector<bool> bias_propagate_down_;
  // NOTE(review): presumably the index of the bias blob within this layer's
  // parameter blobs -- not used in the code shown here; confirm in LayerSetUp.
  int bias_param_id_;
  // Vector used as the multiplier operand of the dot / gemv reductions in
  // Backward_cpu. NOTE(review): presumably filled with ones so that those
  // products act as sums -- confirm where it is initialized.
  Blob<Dtype> sum_multiplier_;
  // Scratch buffer for the partial sums produced by the first reduction stage
  // of Backward_cpu when a second stage (over outer_dim_) is still needed.
  Blob<Dtype> sum_result_;
  // Scratch buffer: Backward_cpu reads the bottom data from it, and may store
  // the element-wise product in it, when the layer runs in place.
  Blob<Dtype> temp_;
  // NOTE(review): presumably the resolved (canonical) value of the `axis`
  // parameter -- not used in the code shown here; confirm.
  int axis_;
  // Loop extents used by Backward_cpu: the data is walked as
  // outer_dim_ x scale_dim_ slices of inner_dim_ contiguous elements, with one
  // scale value per index along scale_dim_.
  int outer_dim_, scale_dim_, inner_dim_;

反向传播实现(Backward_cpu)

/**
 * CPU backward pass of the Scale layer (y = scale * x, optionally + bias).
 *
 * Performs up to three steps, each guarded by its own propagate flag:
 *   1. bias gradient  -- delegated to bias_layer_->Backward;
 *   2. scale gradient -- top_diff * bottom_data, reduced over the inner and
 *      outer dimensions down to scale_dim_ values;
 *   3. bottom gradient -- top_diff multiplied by the matching scale value.
 *
 * @param top             top[0] supplies the incoming gradient (cpu_diff).
 * @param propagate_down  [0]: compute the gradient w.r.t. bottom[0];
 *                        [1]: compute the gradient w.r.t. the scale, consulted
 *                        only when the scale is supplied as bottom[1].
 * @param bottom          bottom[0] is the input; when a second bottom is
 *                        present it is used as the scale instead of
 *                        this->blobs_[0].
 */
template <typename Dtype>
void ScaleLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  // Step 1: bias gradient. Only runs when a bias layer exists and the last
  // entry of param_propagate_down_ is set.
  if (bias_layer_ &&
      this->param_propagate_down_[this->param_propagate_down_.size() - 1]) {
    bias_layer_->Backward(top, bias_propagate_down_, bias_bottom_vec_);
  }
  // With a single bottom the scale is this layer's own blob (blobs_[0]);
  // otherwise it is the second bottom.
  const bool scale_param = (bottom.size() == 1);
  Blob<Dtype>* scale = scale_param ? this->blobs_[0].get() : bottom[1];
  // Step 2: scale gradient, if requested for whichever blob holds the scale.
  if ((!scale_param && propagate_down[1]) ||
      (scale_param && this->param_propagate_down_[0])) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const bool in_place = (bottom[0] == top[0]);
    // In place, bottom[0] and top[0] are the same blob, so the input is read
    // from temp_ instead. NOTE(review): presumably Forward saved a copy of the
    // input there -- confirm in Forward_cpu.
    const Dtype* bottom_data = (in_place ? &temp_ : bottom[0])->cpu_data();
    // Hack: store big eltwise product in bottom[0] diff, except in the special
    // case where this layer itself does the eltwise product, in which case we
    // can store it directly in the scale diff, and we're done.
    // If we're computing in-place (and not doing eltwise computation), this
    // hack doesn't work and we store the product in temp_.
    const bool is_eltwise = (bottom[0]->count() == scale->count());
    Dtype* product = (is_eltwise ? scale->mutable_cpu_diff() :
        (in_place ? temp_.mutable_cpu_data() : bottom[0]->mutable_cpu_diff()));
    // product[i] = top_diff[i] * bottom_data[i] for every element.
    // NOTE(review): in the eltwise case this writes straight into the scale
    // diff, overwriting whatever was there even when the scale is a parameter.
    caffe_mul(top[0]->count(), top_diff, bottom_data, product);
    if (!is_eltwise) {
      // First reduction stage: collapse each run of inner_dim_ elements.
      Dtype* sum_result = NULL;
      if (inner_dim_ == 1) {
        // Nothing to collapse; the product already has one value per
        // (outer, scale) index.
        sum_result = product;
      } else if (sum_result_.count() == 1) {
        // A single output value: one dot product yields the final gradient.
        // A parameter's diff is accumulated (+=); a bottom's diff is
        // overwritten (=).
        // NOTE(review): sum_result stays NULL on this path; this relies on
        // outer_dim_ == 1 here so the second stage below is skipped -- confirm.
        const Dtype* sum_mult = sum_multiplier_.cpu_data();
        Dtype* scale_diff = scale->mutable_cpu_diff();
        if (scale_param) {
          Dtype result = caffe_cpu_dot(inner_dim_, product, sum_mult);
          *scale_diff += result;
        } else {
          *scale_diff = caffe_cpu_dot(inner_dim_, product, sum_mult);
        }
      } else {
        // Treat product as a (sum_result_.count() x inner_dim_) matrix and
        // multiply it by sum_mult. The output goes directly into the scale
        // diff when there is no outer dimension left to reduce, otherwise into
        // the sum_result_ scratch buffer.
        // NOTE(review): beta is 0, so when outer_dim_ == 1 the scale diff is
        // overwritten rather than accumulated, even for a parameter.
        const Dtype* sum_mult = sum_multiplier_.cpu_data();
        sum_result = (outer_dim_ == 1) ?
            scale->mutable_cpu_diff() : sum_result_.mutable_cpu_data();
        caffe_cpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_,
                       Dtype(1), product, sum_mult, Dtype(0), sum_result);
      }
      // Second reduction stage: collapse the outer dimension.
      if (outer_dim_ != 1) {
        const Dtype* sum_mult = sum_multiplier_.cpu_data();
        Dtype* scale_diff = scale->mutable_cpu_diff();
        if (scale_dim_ == 1) {
          // Scalar scale: a dot product over the outer_dim_ partial sums.
          if (scale_param) {
            Dtype result = caffe_cpu_dot(outer_dim_, sum_mult, sum_result);
            *scale_diff += result;
          } else {
            *scale_diff = caffe_cpu_dot(outer_dim_, sum_mult, sum_result);
          }
        } else {
          // Transposed product of the (outer_dim_ x scale_dim_) partial sums
          // with sum_mult. beta = Dtype(scale_param): 1 accumulates into a
          // parameter's diff, 0 overwrites a bottom's diff.
          caffe_cpu_gemv(CblasTrans, outer_dim_, scale_dim_,
                         Dtype(1), sum_result, sum_mult, Dtype(scale_param),
                         scale_diff);
        }
      }
    }
  }
  // Step 3: gradient w.r.t. the input. Runs after step 2 because step 2 may
  // use bottom[0]'s diff as scratch space for the product.
  if (propagate_down[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* scale_data = scale->cpu_data();
    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
    // Walk the data slice by slice: each run of inner_dim_ elements shares the
    // scale value scale_data[d].
    for (int n = 0; n < outer_dim_; ++n) {
      for (int d = 0; d < scale_dim_; ++d) {
        const Dtype factor = scale_data[d];
        // bottom_diff[0..inner_dim_) = factor * top_diff[0..inner_dim_)
        caffe_cpu_scale(inner_dim_, factor, top_diff, bottom_diff);
        bottom_diff += inner_dim_;
        top_diff += inner_dim_;
      }
    }
  }
}

猜你喜欢

转载自blog.csdn.net/Iriving_shu/article/details/79177151