Python包sklearn画ROC曲线和PR曲线

前言

关于ROC和PR曲线的介绍请参考:
机器学习:准确率(Precision)、召回率(Recall)、F值(F-Measure)、ROC曲线、PR曲线

参考:
Python下使用sklearn绘制ROC曲线(超详细)
Python绘图|Python绘制ROC曲线和PR曲线

源码

from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt

def draw_roc(labels, preds):
    '''
    labels: list
    preds: list
    '''
    fpr, tpr, thersholds = roc_curve(labels, preds, pos_label=1) # pos_label指定哪个标签为正样本
    roc_auc = auc(fpr, tpr)  # 计算ROC曲线下面积

    plt.figure(figsize=(10,7), dpi=300)
    plt.plot(fpr, tpr, '-', color='r', label='ROC (area=%.6f)' % (roc_auc), lw=2)
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    # plt.show()
    plt.savefig('./roc.png', dpi=300, bbox_inches='tight')

def draw_pr(labels, preds):
    '''
    labels: list
    preds: list
    '''
    precision, recall, thersholds = precision_recall_curve(labels, preds, pos_label=1) # pos_label指定哪个标签为正样本
    area = average_precision_score(labels, preds, pos_label=1)  # 计算PR曲线下面积

    plt.figure(figsize=(10,7), dpi=300)
    plt.plot(recall, precision, '-', color='r', label='PR (area=%.6f)' % (area), lw=2)
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('PR Curve')
    plt.legend(loc="lower left")
    # plt.show()
    plt.savefig('./pr.png', dpi=300, bbox_inches='tight')

猜你喜欢

转载自blog.csdn.net/qq_33757398/article/details/132321290
今日推荐