python画混淆矩阵

对于分类问题,为了直观表示各类别分类的准确性,一般使用混淆矩阵M. 

混淆矩阵M的每一行代表每个真实类(GT),每一列表示预测的类。即:Mij表示GroundTruth类别为i的所有数据中被预测为类别j的数目。

这里给出两种方法画混淆矩阵。

方法一:这里采用画图像的办法,绘制混淆矩阵的表示图。颜色越深,值越大。

# -*- coding: utf-8 -*-
# By Changxu Cheng, HUST

from __future__ import division
import  numpy as np
from skimage import io, color
from PIL import Image, ImageDraw, ImageFont
import os

def drawCM(matrix, savname):
    # Display different color for different elements
    lines, cols = matrix.shape
    sumline = matrix.sum(axis=1).reshape(lines, 1)
    ratiomat = matrix / sumline
    toplot0 = 1 - ratiomat
    toplot = toplot0.repeat(50).reshape(lines, -1).repeat(50, axis=0)
    io.imsave(savname, color.gray2rgb(toplot))
    # Draw values on every block
    image = Image.open(savname)
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(os.path.join(os.getcwd(), "draw/ARIAL.TTF"), 15)
    for i in range(lines):
        for j in range(cols):
            dig = str(matrix[i, j])
            if i == j:
                filled = (255, 181, 197)
            else:
                filled = (46, 139, 87)
            draw.text((50 * j + 10, 50 * i + 10), dig, font=font, fill=filled)
    image.save(savname, 'jpeg')

if __name__ == "__main__":
    drawCM(np.random.randint(16, size=16).reshape(4,4), 'tmp.jpg')

注意:需要用到字体文件。代码中使用的是ARIAL.TTF。这样才可以在图中直接标注出数目。

某实验结果图如下(不是上述__name == "__main__"代码的执行结果)


方法二:利用matplotlib.pyplot.matshow画图

from __future__ import division
import  numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

def plotCM(classes, matrix, savname):
    """classes: a list of class names"""
    # Normalize by row
    matrix = matrix.astype(np.float)
    linesum = matrix.sum(1)
    linesum = np.dot(linesum.reshape(-1, 1), np.ones((1, matrix.shape[1])))
    matrix /= linesum
    # plot
    plt.switch_backend('agg')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    for i in range(matrix.shape[0]):
        ax.text(i, i, str('%.2f' % (matrix[i, i] * 100)), va='center', ha='center')
    ax.set_xticklabels([''] + classes, rotation=90)
    ax.set_yticklabels([''] + classes)
    #save
    plt.savefig(savname)

这种方法可以直接标出坐标轴的含义,比较方便。



猜你喜欢

转载自blog.csdn.net/qq_27061325/article/details/80433619