MXNet Study Notes (5): Hands-on Multi-task Learning

Introduction

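This series records my whole process of getting started with MXNet. Before MXNet I mainly used Keras and a little TensorFlow, so I already had some experience with deep learning projects. My main reference is the official MXNet tutorials, supplemented by a few useful blog posts.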

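Each post is organized as follows: first the goals I set for the stage, then notes taken while working toward each goal (only what I consider important), and finally the questions I raised along the way (solved and unsolved, some of them speculative; discussion is welcome).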


Goals for this stage

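Task | Priority | Estimated time | Status | Problems | Notes
Define the multi-task data format | P0 | 2 h | ✓ | |
Define the multi-task network | P1 | 0.5 h | ✓ | |
Define the multi-task evaluation metrics | P2 | 1.5 h | ✓ | |
Train and evaluate the network | | | | |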

Notes

Multi-task vs. multi-label

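  • Multi-task is more general than multi-label: the network can branch in its intermediate layers.

  • Multi-label is a special case of multi-task. When every task is a binary classification, the problem is multi-label; in a multi-task setup, each task can be a multi-class classification.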
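To make the distinction concrete, here is a small NumPy sketch of how the labels differ. The batch size, class counts, and values are made up for illustration: a multi-label batch is a single 0/1 matrix with one column per label, while a multi-task batch is one integer class vector per task, and each task may have its own number of classes. Splitting the multi-label matrix column by column turns it into a multi-task problem whose tasks are all binary.

```python
import numpy as np

# Multi-label: 4 samples, 3 binary labels each -> one (N, 3) matrix of 0/1.
multi_label = np.array([[1, 0, 1],
                        [0, 0, 1],
                        [1, 1, 0],
                        [0, 1, 0]])

# Multi-task: one integer class vector per task; tasks may differ in class count.
task_a = np.array([2, 0, 4, 1])   # hypothetical 5-class task
task_b = np.array([1, 0, 1, 1])   # hypothetical binary task
multi_task = [task_a, task_b]

# Each multi-label column becomes the label vector of its own 2-class task.
as_multi_task = [multi_label[:, k] for k in range(multi_label.shape[1])]

num_classes = [int(t.max()) + 1 for t in multi_task]
print(num_classes)                                               # [5, 2]
print(all(set(np.unique(t)) <= {0, 1} for t in as_multi_task))  # True
```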

Defining a multi-task network in MXNet

  • Code:

  • Diagram:
The network branches after flatten0.

Defining the multi-task evaluation metrics

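There is plenty of material online about multi-task metrics, but nearly all of it covers only multi_accuracy. Here I collect single-task and multi-task versions of accuracy, cross-entropy, precision, and recall.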
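For reference, these are the per-task quantities the classes below accumulate. For task t, p_{n,c} is the softmax output for sample n and class c, y_n is the sample's label, and N_t is the number of samples the metric counts for that task. The code accumulates the sums across batches and divides only in get(). Recall and precision treat class 0 as the positive class: TP_t counts (label 0, prediction 0), FN_t counts (label 0, prediction not 0), and FP_t counts (label neither 0 nor -1, prediction 0). Samples labeled -1 are skipped wherever the code skips them.

```latex
\begin{aligned}
\hat{y}_n &= \arg\max_c \, p_{n,c} \\
\mathrm{Acc}_t &= \frac{1}{N_t}\sum_{n=1}^{N_t} \mathbb{1}\left[\hat{y}_n = y_n\right] \\
\mathrm{CE}_t &= -\frac{1}{N_t}\sum_{n=1}^{N_t} \log\left(p_{n,y_n} + \epsilon\right) \\
\mathrm{Recall}_t &= \frac{TP_t}{TP_t + FN_t}, \qquad
\mathrm{Precision}_t = \frac{TP_t}{TP_t + FP_t}
\end{aligned}
```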

  • Single-task accuracy:

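import mxnet as mx


class Accuracy(mx.metric.EvalMetric):
    """Single-task accuracy: fraction of samples whose argmax matches the label."""

    def __init__(self, name='accuracy'):
        # In MXNet 1.x, EvalMetric.__init__ is (name, output_names, label_names, ...),
        # so a second positional argument would be taken as output_names.
        super(Accuracy, self).__init__(name)

    def update(self, labels, preds):
        # argmax over the class axis gives each sample's predicted class
        pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        label = labels[0].asnumpy().astype('int32')

        mx.metric.check_label_shapes(label, pred_label)

        self.sum_metric += (pred_label.flat == label.flat).sum()
        self.num_inst += len(pred_label.flat)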
  • The official multi_accuracy function for multi-label image classification:

class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""
 
    def __init__(self, num=None):
        self.num = num
 
        super(Multi_Accuracy, self).__init__('multi-accuracy')
 
    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.num_inst = 0 if self.num is None else [0] * self.num
        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num
 
    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
 
        if self.num is not None:
            assert len(labels) == self.num
 
        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')
 
            mx.metric.check_label_shapes(label, pred_label)
 
            if self.num is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)
 
    def get(self):
        """Gets the current evaluation result.
 
        Returns
        -------
        names : list of str
           Name of the metrics.
        values : list of float
           Value of the evaluations.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get()
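        else:
            # One name/value pair per task. Build real lists: in Python 3, zip()
            # returns a lazy iterator that can only be consumed once.
            names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
            values = [float('nan') if self.num_inst[i] == 0
                      else self.sum_metric[i] / self.num_inst[i]
                      for i in range(self.num)]
            return names, values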
 
    def get_name_value(self):
        """Returns zipped name and value pairs.
 
        Returns
        -------
        list of tuples
            A (name, value) tuple list.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get_name_value()
        name, value = self.get()
        return list(zip(name, value))
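Without the MXNet plumbing, `Multi_Accuracy` applies the same steps as `Accuracy` to each output in turn. It takes the argmax over the class axis, compares it with that task's label, and keeps a separate correct/total count per task. The NumPy sketch below reproduces that arithmetic for a single batch, together with the `multi-accuracy-task%d` naming used in `get()`, which is handy for checking the metric by hand. The function name and the toy scores are hypothetical.

```python
import numpy as np


def multi_accuracy(labels, preds, name='multi-accuracy'):
    """Per-task accuracy; labels[i] is (N,), preds[i] is (N, C_i) class scores."""
    assert len(labels) == len(preds)
    names, values = [], []
    for i, (label, pred) in enumerate(zip(labels, preds)):
        pred_label = np.argmax(pred, axis=1).astype('int32')  # like argmax_channel
        label = np.asarray(label).astype('int32').ravel()
        assert pred_label.shape == label.shape
        correct = int((pred_label == label).sum())
        names.append('%s-task%d' % (name, i))
        values.append(correct / len(label) if len(label) else float('nan'))
    return names, values


labels = [np.array([0, 2, 1, 1]), np.array([1, 0, 0, 1])]
preds = [np.array([[0.7, 0.2, 0.1],    # -> 0, correct
                   [0.1, 0.1, 0.8],    # -> 2, correct
                   [0.5, 0.4, 0.1],    # -> 0, wrong (label 1)
                   [0.2, 0.6, 0.2]]),  # -> 1, correct
         np.array([[0.3, 0.7],         # -> 1, correct
                   [0.9, 0.1],         # -> 0, correct
                   [0.4, 0.6],         # -> 1, wrong (label 0)
                   [0.2, 0.8]])]       # -> 1, correct
print(list(zip(*multi_accuracy(labels, preds))))
# [('multi-accuracy-task0', 0.75), ('multi-accuracy-task1', 0.75)]
```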

To use it, set the num argument of Multi_Accuracy (for example Multi_Accuracy(num=3)) to the number of tasks; one accuracy is computed per task.

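import mxnet as mx
from my_metric import *  # my_metric.py: the author's module holding the metric classes

eval_metric = mx.metric.CompositeEvalMetric()
eval_metric.add(Multi_Accuracy(num=2))  # one accuracy per task of a two-task network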

The cross-entropy, recall, and precision metrics below can each be used for single-task or multi-task evaluation by setting num and name.
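All three classes share one bookkeeping pattern. With `num=None` they keep a single scalar running sum and count, so every output is pooled into one number. With `num=k` they keep one sum and one count per task and report NaN for a task that has not seen any samples yet. Below is a plain-Python sketch of that pattern. The class `PerTaskAverage` is hypothetical and not part of MXNet, and its per-task names follow the `%s-task%d` convention of `Multi_Accuracy.get()`.

```python
class PerTaskAverage(object):
    """Running ratio sum/count, pooled over tasks (num=None) or kept per task."""

    def __init__(self, name, num=None):
        self.name = name
        self.num = num
        self.reset()

    def reset(self):
        if self.num is None:
            self.sum_metric, self.num_inst = 0.0, 0
        else:
            self.sum_metric = [0.0] * self.num
            self.num_inst = [0] * self.num

    def update(self, task, total, count):
        """Add `total` accumulated over `count` samples of task index `task`."""
        if self.num is None:
            self.sum_metric += total       # task index ignored: everything pooled
            self.num_inst += count
        else:
            self.sum_metric[task] += total
            self.num_inst[task] += count

    def get(self):
        if self.num is None:
            value = self.sum_metric / self.num_inst if self.num_inst else float('nan')
            return self.name, value
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [s / n if n else float('nan')
                  for s, n in zip(self.sum_metric, self.num_inst)]
        return names, values


pooled = PerTaskAverage('recall')
per_task = PerTaskAverage('recall', num=2)
for m in (pooled, per_task):
    m.update(0, 3, 4)   # task 0: 3 hits out of 4
    m.update(1, 1, 4)   # task 1: 1 hit out of 4
print(pooled.get())     # ('recall', 0.5)
print(per_task.get())   # (['recall-task0', 'recall-task1'], [0.75, 0.25])
```

The pooled value (4/8 = 0.5) can hide a weak task, which is why the per-task form is usually more informative for multi-task evaluation.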

  • Cross-entropy

import numpy as np
import mxnet as mx


class CrossEntropy(mx.metric.EvalMetric):
    """Cross-entropy of softmax outputs; one value (num=None) or one per task."""

    def __init__(self, eps=1e-12, name='cross-entropy',
                 output_names=None, label_names=None, num=None):
        super(CrossEntropy, self).__init__(
            name, eps=eps,
            output_names=output_names, label_names=label_names)
        self.eps = eps
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        # EvalMetric.__init__ already calls reset() before self.num is set
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        # enumerate() advances the task index for every output, also when num is None
        for i, (label, pred) in enumerate(zip(labels, preds)):
            label = label.asnumpy().ravel()
            pred = pred.asnumpy()
            assert label.shape[0] == pred.shape[0]

            num_ignored = 0
            if i == 1:
                # In the author's data, task 1 (the "sexy" task) marks unlabeled
                # images with -1. Point them at class 0 and set their predicted
                # probabilities to 1 so they contribute essentially zero loss.
                sexy_index = np.where(np.int64(label) == -1)
                label[sexy_index] = 0.0
                pred[sexy_index] = 1.0
                num_ignored = len(sexy_index[0])

            # probability assigned to each sample's true class
            prob = pred[np.arange(label.shape[0]), np.int64(label)]
            loss = (-np.log(prob + self.eps)).sum()
            count = label.shape[0] - num_ignored
            if self.num is None:
                self.sum_metric += loss
                self.num_inst += count
            else:
                self.sum_metric[i] += loss
                self.num_inst[i] += count

    def get(self):
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            return (self.name, self.sum_metric / self.num_inst)
        # one name per task so CompositeEvalMetric pairs names and values correctly
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [s / n if n != 0 else float('nan')
                  for s, n in zip(self.sum_metric, self.num_inst)]
        return (names, values)
  • Recall

import numpy as np
import mxnet as mx


class Recall(mx.metric.EvalMetric):
    """Recall of class 0 (the positive class); samples labeled -1 are skipped."""

    def __init__(self, name, num=None):
        super(Recall, self).__init__('Recall')
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        for i, (label, pred) in enumerate(zip(labels, preds)):
            pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32').ravel()
            label = label.asnumpy().astype('int32').ravel()

            valid = label != -1                                 # -1: no label for this task
            true_pos = int(np.sum(valid & (pred == 0) & (label == 0)))
            actual_pos = int(np.sum(valid & (label == 0)))      # TP + FN
            if self.num is None:
                self.sum_metric += true_pos
                self.num_inst += actual_pos
            else:
                self.sum_metric[i] += true_pos
                self.num_inst[i] += actual_pos

    def get(self):
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            return (self.name, self.sum_metric / self.num_inst)
        # one name per task so CompositeEvalMetric pairs names and values correctly
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [s / n if n != 0 else float('nan')
                  for s, n in zip(self.sum_metric, self.num_inst)]
        return (names, values)
  • Precision

import numpy as np
import mxnet as mx


class Precision(mx.metric.EvalMetric):
    """Precision of class 0 (the positive class); samples labeled -1 are skipped."""

    def __init__(self, name, num=None):
        super(Precision, self).__init__('Precision')
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        for i, (label, pred) in enumerate(zip(labels, preds)):
            pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32').ravel()
            label = label.asnumpy().astype('int32').ravel()

            valid = label != -1                                 # -1: no label for this task
            true_pos = int(np.sum(valid & (pred == 0) & (label == 0)))
            pred_pos = int(np.sum(valid & (pred == 0)))         # TP + FP
            if self.num is None:
                self.sum_metric += true_pos
                self.num_inst += pred_pos
            else:
                self.sum_metric[i] += true_pos
                self.num_inst[i] += pred_pos

    def get(self):
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            return (self.name, self.sum_metric / self.num_inst)
        # one name per task so CompositeEvalMetric pairs names and values correctly
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [s / n if n != 0 else float('nan')
                  for s, n in zip(self.sum_metric, self.num_inst)]
        return (names, values)
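To check the counting rules above by hand, here is a NumPy sketch for one task and one batch. It follows the classes' conventions: class 0 is the positive class for recall and precision, and a label of -1 means the sample has no label for this task and is skipped. The CrossEntropy class applies the -1 rule only to task index 1, whereas this sketch applies it to whichever task it is given. The function name and the toy numbers are made up for illustration.

```python
import numpy as np


def masked_task_metrics(prob, label, eps=1e-12):
    """prob: (N, C) softmax outputs; label: (N,) ints where -1 means unlabeled."""
    prob = np.asarray(prob, dtype=float)
    label = np.asarray(label).astype('int64').ravel()
    valid = label != -1                        # drop unlabeled samples
    prob, label = prob[valid], label[valid]
    pred = np.argmax(prob, axis=1)

    # cross-entropy: mean of -log p(true class) over the labeled samples
    ce = float(np.mean(-np.log(prob[np.arange(len(label)), label] + eps)))

    true_pos = int(np.sum((pred == 0) & (label == 0)))
    actual_pos = int(np.sum(label == 0))       # TP + FN
    pred_pos = int(np.sum(pred == 0))          # TP + FP
    recall = true_pos / actual_pos if actual_pos else float('nan')
    precision = true_pos / pred_pos if pred_pos else float('nan')
    return ce, recall, precision


prob = np.array([[0.9, 0.1],     # label 0, pred 0 -> true positive
                 [0.4, 0.6],     # label 0, pred 1 -> false negative
                 [0.3, 0.7],     # label 0, pred 1 -> false negative
                 [0.8, 0.2],     # label 1, pred 0 -> false positive
                 [0.2, 0.8],     # label 1, pred 1 -> true negative
                 [0.95, 0.05]])  # label -1 -> skipped
label = np.array([0, 0, 0, 1, 1, -1])
ce, recall, precision = masked_task_metrics(prob, label)
print(round(recall, 4), precision)  # 0.3333 0.5
```

Without the mask, the last row would count as a false positive and precision would drop from 1/2 to 1/3, which is why both the loop-based classes and this sketch filter out -1 before counting.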

Reposted from blog.csdn.net/s000da/article/details/90438701