利用鸢尾花(Iris)数据集画出 P-R 曲线

利用鸢尾花(Iris)数据集画出 P-R 曲线

0 导入相关库

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import train_test_split

1 导入数据集

datasets.load_iris() # 数据集

# Load the Iris dataset bundle, then pull out the feature matrix and the
# integer class labels in one step.
iris = datasets.load_iris()
X, y = iris.data, iris.target

2 数据预处理

标签二值化(3个类 0, 1, 2 -> 100, 010, 001,即 one-hot 编码)
sklearn.preprocessing.label_binarize
OneVsRestClassifier策略做铺垫

# One-hot encode the three integer labels so that every class gets its own
# binary column; the per-class curves below are computed column by column.
y = label_binarize(y, classes=[0, 1, 2])
# The binarized matrix has one column per class, so its last dimension is
# the number of classes.
n_classes = y.shape[-1]

在这里插入图片描述

增加800维噪声特征(200 × 原始特征数 4 = 800)
np.random.RandomState
np.random.RandomState.randn
np.c_

# Seeded generator so the synthetic noise drawn below is reproducible.
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
# Draw 200 standard-normal noise columns per original feature and append
# them to the right of the real features.
noise = random_state.randn(n_samples, 200 * n_features)
X = np.hstack((X, noise))
X.shape  # bare expression: only displays the new shape in an interactive session

在这里插入图片描述

数据集切分(训练集&测试集)
sklearn.model_selection.train_test_split

# Hold out half of the samples for testing; the RandomState instance is
# passed through to control the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.5,
    random_state=random_state,
)

3 计算精度召回和绘制曲线

sklearn.multiclass.OneVsRestClassifier.fit().decision_function
svm.SVC

# Wrap a linear-kernel SVM in a one-vs-rest classifier and fit it on the
# binarized training labels.
base_svm = svm.SVC(kernel='linear', probability=True, random_state=random_state)
classifier = OneVsRestClassifier(base_svm)
classifier.fit(X_train, y_train)
# Decision-function outputs (not probabilities) are used as the ranking
# scores for the precision-recall computation.
y_score = classifier.decision_function(X_test)

计算每个类的精确率(precision)、召回率(recall)和平均精确率(average precision)
_ :接收返回的阈值(thresholds);这里用不到阈值,所以用临时性名称 _ 丢弃
sklearn.metrics.precision_recall_curve
sklearn.metrics.average_precision_score

# Per-class curve data, keyed by class index.
# `precision` must be created here as well: the loop assigns into it, and
# without this initialisation the first iteration raises NameError.
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    # Column i holds the binary ground truth and the scores for class i.
    # The third return value (the thresholds) is not needed, hence `_`.
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i], y_score[:, i])
    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])

计算微观平均曲线和面积
numpy.ndarray.ravel() # 将多维数组降为一维

# Micro-average: flatten the label and score matrices so that every
# (sample, class) pair counts as one binary decision, then compute a
# single curve over all of them. The thresholds are discarded.
micro_curve = precision_recall_curve(y_test.ravel(), y_score.ravel())
precision["micro"], recall["micro"], _ = micro_curve
average_precision["micro"] = average_precision_score(
    y_test, y_score, average="micro"
)

4 画出P-R曲线

绘制每个类的精确召回曲线

plt.clf()  # clear the current figure before drawing

# Micro-averaged curve first, labelled with its average precision.
micro_label = 'micro-average Precision-recall curve (area = {0:0.2f})'.format(
    average_precision["micro"])
plt.plot(recall["micro"], precision["micro"], label=micro_label)

# Then one curve per class, each labelled with its own average precision.
for i in range(n_classes):
    class_label = 'Precision-recall curve of class {0} (area = {1:0.2f})'.format(
        i, average_precision[i])
    plt.plot(recall[i], precision[i], label=class_label)

# Axis display ranges: x is limited to [0, 1]; y gets a little headroom above 1.
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=16)
plt.ylabel('Precision', fontsize=16)
plt.title('Extension of Precision-Recall curve to multi-class', fontsize=20)
plt.legend(loc="lower left")  # put the legend in the lower-left corner
plt.show()

在这里插入图片描述

发布了50 篇原创文章 · 获赞 51 · 访问量 2491

猜你喜欢

转载自blog.csdn.net/hezuijiudexiaobai/article/details/104484402
今日推荐