focal loss和dmi loss的keras(tf.keras)实现

话不多说

focal loss

原文
主要解决分类问题中类别不平衡、分类难度差异的问题

损失函数实现:

from tensorflow.keras import backend as K
def categorical_focal_loss_fixed(y_true, y_pred):
            """
            :param y_true: A tensor of the same shape as `y_pred`
            :param y_pred: A tensor resulting from a softmax
            :return: Output tensor.
            """

            # Scale predictions so that the class probas of each sample sum to 1
            y_pred /= K.sum(y_pred, axis=-1, keepdims=True)

            # Clip the prediction value to prevent NaN's and Inf's
            epsilon = K.epsilon()
            y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

            # Calculate Cross Entropy
            cross_entropy = -y_true * K.log(y_pred)

            # Calculate Focal Loss
            loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

            # Compute mean loss in mini_batch
            return K.mean(loss, axis=1)

DMI loss

介绍文章

实现:

import tensorflow as tf
def dmi_loss(y_true, y_pred):
            """
            y_true为onehot真实标签
            y_pred为softmax后分数
            """
            y_true = tf.transpose(y_true, perm=[1, 0])
            mat = tf.matmul(y_true, y_pred)
            loss = -1.0 * tf.math.log(tf.math.abs(tf.linalg.det(mat)) + 0.001)
            return loss

损失函数可能跑到负值,但是work。

但是dmi也会导致其对高precision下召回的负面影响。

使用交叉熵损失:
使用dmi loss
使用dmi loss:
在这里插入图片描述
尽管两者的准确度相差不大,但是明显交叉熵的pr曲线比dmi更好一些。

下面是公众号,欢迎扫描二维码,谢谢关注,谢谢支持!

公众号名称: Python入坑NLP
公众号
本公众号主要致力于自然语言处理、机器学习、coding算法以及Python的一些知识分享。本人只是小菜,希望记录自己学习、工作过程的同时,大家一起进步。欢迎交流、分享。

猜你喜欢

转载自blog.csdn.net/lovoslbdy/article/details/107169797