Use scikit-plot to visualize trained machine learning models (including multi-class ROC curves, confusion matrices, etc.)

Table of contents

1. Installation

2. Plotting examples

1) Visualizing evaluation metrics

1. Confusion matrix

2. Multi-class ROC curves

3. KS statistic plot

4. Precision-recall (PR) curve

5. Silhouette analysis

6. Classifier calibration curves

2) Model visualization

1. Learning curves (training and cross-validation scores vs. training-set size)

2. Feature importances

3) Cluster visualization

1. Elbow curve for clustering

4) Dimensionality reduction visualization

1. Explained variance ratio of the PCA components

2. Scatter plot after PCA dimensionality reduction


scikit-learn (sklearn) is a widely used machine learning library for Python that provides common classification, regression, and clustering algorithms. After a model has been trained, a common next step is to visualize it, which is usually done with Matplotlib.

scikit-plot is a library built on top of sklearn and Matplotlib. Its main purpose is to visualize trained models, and it offers a fairly small set of functions that are easy to understand and use.

1. Installation

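pip install scikit-plot -i https://pypi.tuna.tsinghua.edu.cn/simple

The -i option points pip at the Tsinghua University PyPI mirror; drop it to install from the default PyPI index.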

2. Plotting examples

1) Visualizing evaluation metrics

1. Confusion matrix

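import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)  # example data (Iris, 3 classes)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1, stratify=y)
rf = RandomForestClassifier()
rf = rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
# normalize=True shows the proportion of each true class instead of raw counts
skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True)
plt.show()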
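With normalize=True, each row of the confusion matrix (one true class) is divided by the number of samples in that class, so the diagonal shows the share of each class that was predicted correctly. A minimal NumPy sketch of that computation; the function name normalized_confusion_matrix is hypothetical and not part of scikit-plot:

```python
import numpy as np


def normalized_confusion_matrix(y_true, y_pred):
    """Row-normalized confusion matrix: rows are true classes, columns predicted classes."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    classes = np.unique(np.concatenate([y_true, y_pred]))
    index = {label: i for i, label in enumerate(classes)}
    counts = np.zeros((len(classes), len(classes)))
    for t, p in zip(y_true, y_pred):
        counts[index[t], index[p]] += 1
    row_sums = counts.sum(axis=1, keepdims=True)
    # Divide each row by its class size; rows of classes absent from y_true stay 0
    normalized = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)
    return classes, normalized


classes, cm = normalized_confusion_matrix([0, 0, 1, 1, 1, 2], [0, 1, 1, 1, 0, 2])
print(classes)
print(cm.round(2))
```

Each row of the result sums to 1, which is why the normalized plot is easier to compare across classes of different sizes.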

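2. Multi-class ROC curves

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

X, y = load_iris(return_X_y=True)  # example data (Iris, 3 classes)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1, stratify=y)
nb = GaussianNB()
nb = nb.fit(X_train, y_train)
y_probas = nb.predict_proba(X_test)  # one probability column per class
# One ROC curve per class, plus micro- and macro-averaged curves
skplt.metrics.plot_roc(y_test, y_probas)
plt.show()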
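For a multi-class problem, each per-class ROC curve treats one class as positive and all other classes as negative, using that class's column of y_probas as the score. The area under such a curve equals the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one (ties count half). A NumPy sketch of that per-class AUC; the function names are hypothetical, and it assumes column i of y_probas belongs to the i-th sorted class label, which is the column order sklearn's predict_proba uses:

```python
import numpy as np


def binary_auc(y_true, scores):
    """AUC as P(score of a random positive > score of a random negative); ties count 1/2."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    # Compare every positive score with every negative score (O(n_pos * n_neg) memory)
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (wins + 0.5 * ties) / (len(pos) * len(neg))


def one_vs_rest_auc(y_true, y_probas):
    """Per-class AUC, treating each class in turn as positive and all others as negative."""
    y_true = np.asarray(y_true)
    y_probas = np.asarray(y_probas, dtype=float)
    classes = np.unique(y_true)  # assumed to match the column order of y_probas
    return {label.item(): binary_auc(y_true == label, y_probas[:, i])
            for i, label in enumerate(classes)}


y_true = [0, 1, 2, 1]
y_probas = [[0.8, 0.1, 0.1],
            [0.2, 0.5, 0.3],
            [0.1, 0.2, 0.7],
            [0.3, 0.15, 0.55]]
print(one_vs_rest_auc(y_true, y_probas))
```

The pairwise comparison is quadratic in the number of samples, so it only suits small examples; sklearn.metrics.roc_curve and sklearn.metrics.auc compute the same area efficiently.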

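3. KS statistic plot

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# plot_ks_statistic only supports binary labels, so use a two-class dataset
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
lr = LogisticRegression(max_iter=5000)  # extra iterations: the features are unscaled
lr = lr.fit(X_train, y_train)
y_probas = lr.predict_proba(X_test)
skplt.metrics.plot_ks_statistic(y_test, y_probas)
plt.show()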
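The KS statistic is the largest vertical gap between the cumulative distributions of the predicted score for class 0 and for class 1; the chart marks that gap and the threshold where it occurs. A NumPy sketch of the quantity, assuming scores holds the positive-class probability (for example y_probas[:, 1]); ks_statistic is a hypothetical name, and scikit-plot's exact threshold convention may differ slightly:

```python
import numpy as np


def ks_statistic(y_true, scores):
    """Largest gap between the empirical CDFs of the scores of class 0 and class 1."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = np.sort(scores[y_true == 1])
    neg = np.sort(scores[y_true == 0])
    thresholds = np.unique(scores)
    # Share of each class whose score is <= the threshold
    cdf_pos = np.searchsorted(pos, thresholds, side="right") / len(pos)
    cdf_neg = np.searchsorted(neg, thresholds, side="right") / len(neg)
    gaps = np.abs(cdf_neg - cdf_pos)
    best = int(np.argmax(gaps))
    return float(gaps[best]), float(thresholds[best])


# Perfectly separated classes give the maximum KS statistic of 1.0
stat, threshold = ks_statistic([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9])
print(stat, threshold)
```

A KS value near 1 means the two classes' scores barely overlap; a value near 0 means the model hardly separates them.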

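4. Precision-recall (PR) curve

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

X, y = load_iris(return_X_y=True)  # example data (Iris, 3 classes)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1, stratify=y)
nb = GaussianNB()
nb.fit(X_train, y_train)
y_probas = nb.predict_proba(X_test)
# One precision-recall curve per class, plus a micro-averaged curve
skplt.metrics.plot_precision_recall(y_test, y_probas)
plt.show()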
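For each class, a PR curve ranks the test samples by that class's predicted probability and records precision and recall as the cut-off moves down the ranking. A NumPy sketch for one class, including average precision (a common one-number summary of the curve); the function names are hypothetical, and the sketch assumes no tied scores and at least one positive sample:

```python
import numpy as np


def precision_recall_points(y_true, scores):
    """Precision and recall at every cut-off of the ranking (assumes no tied scores)."""
    y_true = np.asarray(y_true, dtype=bool)
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")
    hits = y_true[order]                  # True where the k-th ranked sample is positive
    tp = np.cumsum(hits)                  # true positives among the top k
    k = np.arange(1, len(hits) + 1)       # samples predicted positive at cut-off k
    precision = tp / k
    recall = tp / hits.sum()              # assumes at least one positive sample
    return precision, recall


def average_precision(y_true, scores):
    """Sum of precision weighted by the increase in recall at each cut-off."""
    precision, recall = precision_recall_points(y_true, scores)
    recall_gain = np.diff(recall, prepend=0.0)
    return float(np.sum(recall_gain * precision))


precision, recall = precision_recall_points([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1])
ap = average_precision([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1])
print(precision, recall, ap)
```

For a multi-class y_probas, call it once per class, e.g. average_precision(y_test == c, y_probas[:, i]) for the i-th class label c.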

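5. Silhouette analysis

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X, _ = load_iris(return_X_y=True)  # class labels are not needed for clustering
kmeans = KMeans(n_clusters=4, random_state=1)
cluster_labels = kmeans.fit_predict(X)
# One silhouette profile per cluster; the dashed line marks the average score
skplt.metrics.plot_silhouette(X, cluster_labels)
plt.show()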
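Each sample's silhouette value is s(i) = (b - a) / max(a, b), where a is its mean distance to the other members of its own cluster and b is its mean distance to the members of the nearest other cluster. Values near 1 indicate a well-placed sample, values near 0 a sample on a cluster boundary. A NumPy sketch of that formula with Euclidean distance (silhouette_values is a hypothetical name; sklearn.metrics.silhouette_samples is the library version):

```python
import numpy as np


def silhouette_values(X, labels):
    """s(i) = (b - a) / max(a, b) for each sample; assumes at least two clusters."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    # Pairwise Euclidean distances between all samples
    dist = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    clusters = np.unique(labels)
    values = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # convention: a sample alone in its cluster gets 0
        # a: mean distance to the other members of its own cluster (dist[i, i] is 0)
        a = dist[i, same].sum() / (n_same - 1)
        # b: mean distance to the members of the nearest other cluster
        b = min(dist[i, labels == c].mean() for c in clusters if c != labels[i])
        values[i] = (b - a) / max(a, b)
    return values


values = silhouette_values([[0.0], [1.0], [10.0], [11.0]], [0, 0, 1, 1])
print(values, values.mean())
```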

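6. Classifier calibration curves

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import LinearSVC

# plot_calibration_curve only supports binary labels, so use a two-class dataset
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

rf = RandomForestClassifier()
lr = LogisticRegression(max_iter=5000)
nb = GaussianNB()
svm = LinearSVC(max_iter=10000)
rf_probas = rf.fit(X_train, y_train).predict_proba(X_test)
lr_probas = lr.fit(X_train, y_train).predict_proba(X_test)
nb_probas = nb.fit(X_train, y_train).predict_proba(X_test)
# LinearSVC has no predict_proba, so its decision_function scores are passed instead
svm_scores = svm.fit(X_train, y_train).decision_function(X_test)
probas_list = [rf_probas, lr_probas, nb_probas, svm_scores]
clf_names = ['Random Forest', 'Logistic Regression',
             'Gaussian Naive Bayes', 'Support Vector Machine']

skplt.metrics.plot_calibration_curve(y_test,
                                     probas_list,
                                     clf_names)
plt.show()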
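A calibration (reliability) curve groups the test samples into bins by predicted probability of the positive class and, for each bin, plots the observed fraction of positives against the mean predicted probability; a well-calibrated model stays close to the diagonal. scikit-learn computes these points with sklearn.calibration.calibration_curve. The NumPy sketch below (hypothetical name reliability_curve, equal-width bins) only illustrates the idea and returns its two arrays in the same order as calibration_curve's (prob_true, prob_pred); handling of values exactly on a bin edge may differ from scikit-learn:

```python
import numpy as np


def reliability_curve(y_true, y_prob, n_bins=5):
    """Fraction of positives vs. mean predicted probability in equal-width bins."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Bin index from the interior edges, so 0.0 lands in the first bin and 1.0 in the last
    bin_ids = np.searchsorted(edges[1:-1], y_prob, side="right")
    frac_pos, mean_pred = [], []
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():  # empty bins are skipped
            frac_pos.append(y_true[in_bin].mean())
            mean_pred.append(y_prob[in_bin].mean())
    return np.array(frac_pos), np.array(mean_pred)


frac_pos, mean_pred = reliability_curve([0, 0, 1, 1, 1, 0],
                                        [0.1, 0.15, 0.9, 0.95, 0.55, 0.6], n_bins=2)
print(frac_pos, mean_pred)
```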

2) Model visualization

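1. Learning curves (training and cross-validation scores vs. training-set size)

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)  # example data
rf = RandomForestClassifier()
# Fits rf on growing subsets of X and plots training vs. cross-validation score
skplt.estimators.plot_learning_curve(rf, X, y)
plt.show()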
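plot_learning_curve is built on scikit-learn's sklearn.model_selection.learning_curve, which refits the estimator on growing subsets of each cross-validation training fold. Calling it directly gives the numbers behind the plot. The fold count, size grid, and n_estimators below are illustrative choices for this sketch, not necessarily scikit-plot's defaults:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier(n_estimators=50, random_state=1)

# 5-fold CV; each training fold is subsampled at 5 fractions between 10% and 100%
train_sizes, train_scores, test_scores = learning_curve(
    rf, X, y, cv=5, train_sizes=np.linspace(0.1, 1.0, 5),
    shuffle=True, random_state=1)

# Average over the 5 folds, as the plotted lines do
for n, tr, te in zip(train_sizes, train_scores.mean(axis=1), test_scores.mean(axis=1)):
    print(f"{int(n):4d} training samples: train={tr:.3f}  cross-val={te:.3f}")
```

A large, persistent gap between the two scores suggests overfitting; two low scores that converge suggest the model is too simple for the data.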

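2. Feature importances

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier()
rf.fit(X, y)

# Names must follow the column order of X (see load_iris().feature_names)
skplt.estimators.plot_feature_importances(
    rf, feature_names=['sepal length', 'sepal width',
                       'petal length', 'petal width'])
plt.show()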
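plot_feature_importances draws the fitted model's feature_importances_, which for a random forest are impurity-based importances. A common cross-check is permutation importance, a different technique: shuffle one column and measure how much the score drops. scikit-learn provides it as sklearn.inspection.permutation_importance; the NumPy sketch below (hypothetical name permutation_importance_np, accuracy as the score, a toy model instead of the forest) only illustrates the idea:

```python
import numpy as np


def permutation_importance_np(predict, X, y, feature_names, n_repeats=5, seed=0):
    """Mean drop in accuracy when a single column of X is randomly shuffled."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    baseline = np.mean(predict(X) == y)
    importances = {}
    # enumerate pairs each name with column j, so names must follow the column order of X
    for j, name in enumerate(feature_names):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break the link between column j and y
            drops.append(baseline - np.mean(predict(X_perm) == y))
        importances[name] = float(np.mean(drops))
    return importances


def predict(A):
    """Toy model that only looks at the first feature."""
    return (A[:, 0] > 0).astype(int)


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = predict(X)
importances = permutation_importance_np(predict, X, y, ["informative", "ignored"])
print(importances)
```

Because importances are matched to names by column position, a mislabelled feature_names list silently attaches each score to the wrong feature, in this sketch and in the plot alike.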

3) Cluster visualization

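1. Elbow curve for clustering

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X, _ = load_iris(return_X_y=True)  # example data
kmeans = KMeans(random_state=1)
# X is required: KMeans is refit for each n_clusters value in cluster_ranges
skplt.cluster.plot_elbow_curve(kmeans, X, cluster_ranges=range(1, 30))
plt.show()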
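An elbow curve plots, for each number of clusters k, the within-cluster sum of squared distances (what scikit-learn's KMeans stores as inertia_); the "elbow" where the curve stops dropping sharply suggests a reasonable k. The NumPy sketch below computes that quantity with plain Lloyd's k-means. The farthest-point initialisation is an assumption chosen to keep the example deterministic (scikit-learn's KMeans uses k-means++ by default), and kmeans_inertia is a hypothetical name:

```python
import numpy as np


def kmeans_inertia(X, k, n_iter=100):
    """Within-cluster sum of squared distances after Lloyd's k-means with k clusters."""
    X = np.asarray(X, dtype=float)
    # Deterministic farthest-point init: start at X[0], then add the point farthest from all centers
    centers = [X[0]]
    for _ in range(1, k):
        sq_dist = np.min([((X - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(X[np.argmax(sq_dist)])
    centers = np.array(centers)
    for _ in range(n_iter):
        # Assign every point to its nearest center
        labels = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1).argmin(axis=1)
        # Move each center to the mean of its points (an empty cluster keeps its old center)
        new_centers = np.array([X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                                for j in range(k)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    sq_dist = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return float(sq_dist.min(axis=1).sum())


# Three small, well-separated groups of points
X = np.array([[0, 0], [0, 1], [1, 0],
              [10, 0], [10, 1], [11, 0],
              [0, 10], [0, 11], [1, 10]])
for k in range(1, 6):
    print(k, round(kmeans_inertia(X, k), 3))
```

On this toy data the inertia falls steeply up to k = 3 and only slowly afterwards, which is the elbow shape the plot is meant to reveal.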

4) Dimensionality reduction visualization

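1. Explained variance ratio of the PCA components

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X, _ = load_iris(return_X_y=True)  # example data
pca = PCA(random_state=1)
pca.fit(X)
# Cumulative explained variance ratio against the number of components
skplt.decomposition.plot_pca_component_variance(pca)
plt.show()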
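A PCA component's explained variance ratio is the variance along that principal axis divided by the total variance, and the plot shows its cumulative sum. A NumPy sketch via the singular value decomposition of the mean-centered data; explained_variance_ratio is a hypothetical name, and when all components are kept it should agree with PCA().fit(X).explained_variance_ratio_:

```python
import numpy as np


def explained_variance_ratio(X):
    """Share of the total variance captured by each principal component, largest first."""
    X = np.asarray(X, dtype=float)
    X_centered = X - X.mean(axis=0)  # PCA is computed on mean-centered data
    singular_values = np.linalg.svd(X_centered, compute_uv=False)
    variances = singular_values ** 2  # proportional to the variance along each component
    return variances / variances.sum()


X = [[2, 0], [-2, 0], [0, 1], [0, -1]]
ratio = explained_variance_ratio(X)
print(ratio, np.cumsum(ratio))
```

Reading off where the cumulative sum crosses a target (say 0.95) gives the number of components to keep.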

2. Scatter plot after PCA dimensionality reduction

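import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X, y = load_iris(return_X_y=True)  # example data
pca = PCA(random_state=1)
pca.fit(X)
# Project X onto the first two components and color each point by its label y
skplt.decomposition.plot_pca_2d_projection(pca, X, y)
plt.show()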
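For a PCA without whitening, the 2D projection is the mean-centered data multiplied by the top two principal axes, which is what scikit-learn's PCA.transform computes before the points are colored by label. A NumPy sketch under that assumption (pca_2d_projection is a hypothetical name); the sign of each axis is arbitrary, so the scatter may come out mirrored relative to scikit-learn's:

```python
import numpy as np


def pca_2d_projection(X):
    """Coordinates of each sample on the first two principal components."""
    X = np.asarray(X, dtype=float)
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are the principal axes, sorted by decreasing singular value
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:2].T


X = [[2, 0, 0], [-2, 0, 0], [0, 1, 0], [0, -1, 0]]
Z = pca_2d_projection(X)
print(Z.round(3))
```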


Source: blog.csdn.net/qq_45100200/article/details/131268560