Adding Dice loss to Caffe, with a walkthrough of the implementation

Copyright notice: this is original work, please respect it. https://blog.csdn.net/z13653662052/article/details/80538136

The source files can be downloaded from GitHub.
Put the .cpp and .cu files in /caffe/src/caffe/layers/ and the .hpp file in /caffe/include/caffe/layers/, then rebuild Caffe. The Dice loss layer is then available.
Usage:

layer {
  name: "loss"
  type: "DiceCoefLoss"
  bottom: "Deconv"
  bottom: "label"
  top: "loss"
}

Walkthrough:
How the loss is computed: Dice = (2x^Ty + e) / (x^Tx + y^Ty + e), and Loss = 1 - Dice.
Here x and y are bottom[0] and bottom[1], and e is the smooth term, which defaults to 1.
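To make the formula concrete, here is a minimal stand-alone sketch in plain C++. The helper names dice_coefficient and dice_loss are hypothetical and are not part of Caffe or of the layer's source; the inputs are assumed to be already flattened into vectors of equal length.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not Caffe code): Dice coefficient of two flattened
// vectors x and y with smooth term e: (2 * x.y + e) / (x.x + y.y + e).
double dice_coefficient(const std::vector<double>& x,
                        const std::vector<double>& y,
                        double e = 1.0) {
  assert(x.size() == y.size());
  double xy = 0.0, xx = 0.0, yy = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    xy += x[i] * y[i];  // x^T y
    xx += x[i] * x[i];  // x^T x
    yy += y[i] * y[i];  // y^T y
  }
  return (2.0 * xy + e) / (xx + yy + e);
}

// Hypothetical helper: the loss is one minus the Dice coefficient.
double dice_loss(const std::vector<double>& x,
                 const std::vector<double>& y,
                 double e = 1.0) {
  return 1.0 - dice_coefficient(x, y, e);
}
```

For example, with x = (1, 1, 0, 0), y = (0, 0, 1, 1) and e = 1 there is no overlap, so Dice = 1/5 and the loss is 0.8. With x = y = (1, 0, 1, 0), Dice = 5/5 = 1 and the loss is 0.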
The rest of this post focuses on how Dice loss is actually implemented in Caffe, working through the computation step by step.


Start at line 30 of dice.cpp. This is Forward_cpu, which is where the loss is computed. It makes heavy use of functions such as caffe_mul and caffe_cpu_gemv. These are Caffe's matrix math routines; for details see https://blog.csdn.net/z13653662052/article/details/80516748

  smooth = Dtype(1.);  // smooth = 1
  caffe_set(dim, Dtype(1), multiplier_.mutable_cpu_data());  // fill multiplier_ with ones
  caffe_set(batchsize, smooth, result_tmp_.mutable_cpu_data());  // initialize result_tmp_ to smooth
  caffe_set(batchsize, smooth, result_.mutable_cpu_data());  // initialize result_ to smooth
caffe_mul(bottom[0]->count(), bottom[0]->cpu_data(), bottom[0]->cpu_data(), 
                tmp_.mutable_cpu_data());

tmp_ = data[0] * data[0] (element-wise product)

caffe_cpu_gemv(CblasNoTrans, bottom[0]->num(), bottom[0]->count(1), Dtype(1.), tmp_.cpu_data(), 
                    multiplier_.cpu_data(), Dtype(1.), result_tmp_.mutable_cpu_data());                   

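result_tmp_ = 1 * tmp_ * multiplier_ + 1 * result_tmp_ = sum(tmp_) + 1 = sum(data[0] * data[0]) + 1
Because multiplier_ is a vector of ones, the matrix-vector product sums each sample's row of tmp_, so result_tmp_ holds one value per sample.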
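The gemv step is the least obvious part, so here is a minimal plain-loop sketch of what y = alpha * A * x + beta * y does when A is stored row-major and not transposed. The function gemv_no_trans is a hypothetical stand-in written for illustration. It is not Caffe's or BLAS's actual implementation, and it uses std::vector instead of raw pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a no-transpose gemv: y = alpha * A * x + beta * y,
// where A is a row-major rows x cols matrix.
void gemv_no_trans(std::size_t rows, std::size_t cols, double alpha,
                   const std::vector<double>& A,
                   const std::vector<double>& x,
                   double beta, std::vector<double>& y) {
  assert(A.size() == rows * cols);
  assert(x.size() == cols);
  assert(y.size() == rows);
  for (std::size_t r = 0; r < rows; ++r) {
    double dot = 0.0;
    for (std::size_t c = 0; c < cols; ++c) {
      dot += A[r * cols + c] * x[c];  // with x all ones, this is a row sum
    }
    y[r] = alpha * dot + beta * y[r];  // beta = 1 keeps the old value of y
  }
}
```

With rows = batch size, cols = dim, x filled with ones, and y pre-filled with smooth, each y[r] becomes the sum of row r plus smooth. That is exactly the role result_tmp_ plays above.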

caffe_mul(bottom[1]->count(), bottom[1]->cpu_data(), bottom[1]->cpu_data(), 
                tmp_.mutable_cpu_data());

tmp_ = data[1] * data[1] (element-wise product)

caffe_cpu_gemv(CblasNoTrans, bottom[1]->num(), bottom[1]->count(1), Dtype(1.), tmp_.cpu_data(), 
                    multiplier_.cpu_data(), Dtype(1.), result_tmp_.mutable_cpu_data());

result_tmp_ = 1 * tmp_ * multiplier_ + 1 * result_tmp_ = sum(tmp_) + result_tmp_ = sum(data[1] * data[1]) + sum(data[0] * data[0]) + 1

caffe_mul(bottom[0]->count(), bottom[0]->cpu_data(), bottom[1]->cpu_data(), 
                tmp_.mutable_cpu_data());

tmp_ = data[0] * data[1] (element-wise product)

caffe_cpu_gemv(CblasNoTrans, bottom[1]->num(), bottom[1]->count(1), Dtype(2.), tmp_.cpu_data(), 
                    multiplier_.cpu_data(), Dtype(1.), result_.mutable_cpu_data());

result_ = 2 * tmp_ * multiplier_ + 1 * result_ = 2 * sum(data[0] * data[1]) + 1

caffe_div(bottom[0]->num(), result_.cpu_data(), result_tmp_.cpu_data(), result_.mutable_cpu_data());

result_ = result_ / result_tmp_ = (2 * sum(data[0] * data[1]) + 1) / (sum(data[1] * data[1]) + sum(data[0] * data[0]) + 1)
result_ is the Dice value of each sample, so the Dice loss is 1 - result_.
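Putting the steps together, the forward pass can be re-created without Caffe as the sketch below. It is an illustration under stated assumptions, not the layer's actual code: dice_loss_per_sample is a hypothetical name, the inputs are assumed to be row-major with batch rows of dim values, and the function returns one loss per sample because the excerpt above does not show how the layer reduces them to the scalar top blob.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-alone re-creation of the forward steps (not Caffe code).
// pred and label each hold batch * dim values, stored row-major.
std::vector<double> dice_loss_per_sample(const std::vector<double>& pred,
                                         const std::vector<double>& label,
                                         std::size_t batch, std::size_t dim,
                                         double smooth = 1.0) {
  assert(pred.size() == batch * dim);
  assert(label.size() == batch * dim);
  // Both accumulators start at smooth, like result_tmp_ and result_.
  std::vector<double> result_tmp(batch, smooth);  // denominator
  std::vector<double> result(batch, smooth);      // numerator, then loss
  for (std::size_t n = 0; n < batch; ++n) {
    for (std::size_t i = 0; i < dim; ++i) {
      const double p = pred[n * dim + i];
      const double l = label[n * dim + i];
      result_tmp[n] += p * p + l * l;  // the two gemv calls with alpha = 1
      result[n] += 2.0 * p * l;        // the gemv call with alpha = 2
    }
    result[n] /= result_tmp[n];   // the caffe_div step: per-sample Dice
    result[n] = 1.0 - result[n];  // Dice loss for this sample
  }
  return result;
}
```

For a batch of two samples with dim = 2, pred = (1, 0 | 1, 1) and label = (1, 0 | 0, 0), the first sample matches its label exactly and gets loss 0. The second has no overlap, so its Dice is 1/3 and its loss is 2/3.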


Reprinted from blog.csdn.net/z13653662052/article/details/80538136