使用python绘制混淆矩阵(可直接复制调用)
转自https://www.jianshu.com/p/13debf42fdb7
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
def cm_plot(original_label, predict_label, pic=None):
cm = confusion_matrix(original_label, predict_label) # 由原标签和预测标签生成混淆矩阵
plt.figure()
plt.matshow(cm, cmap=plt.cm.Blues) # 画混淆矩阵,配色风格使用cm.Blues
plt.colorbar() # 颜色标签
for x in range(len(cm)):
for y in range(len(cm)):
plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')
# annotate主要在图形中添加注释
# 第一个参数添加注释
# 第二个参数是注释的内容
# xy设置箭头尖的坐标
# horizontalalignment水平对齐
# verticalalignment垂直对齐
# 其余常用参数如下:
# xytext设置注释内容显示的起始位置
# arrowprops 用来设置箭头
# facecolor 设置箭头的颜色
# headlength 箭头的头的长度
# headwidth 箭头的宽度
# width 箭身的宽度
plt.ylabel('True label') # 坐标轴标签
plt.xlabel('Predicted label') # 坐标轴标签
plt.title('confusion matrix')
if pic is not None:
plt.savefig(str(pic) + '.jpg')
plt.show()