人工智能学习07--pytorch11--分类网络:使用pytorch和tensorflow计算分类模型的混淆矩阵

师兄说学目标检测之前先学分类
坏了,内容好多!学学学
感谢up主,好人一生平安
在这里插入图片描述

混淆矩阵

  1. 什么是混淆矩阵
    在这里插入图片描述
    横坐标:每一列属于该类的所有验证样本。每一列所有元素对应真实类别。
    在这里插入图片描述
    纵坐标:网络的预测类别。每一行对应预测结果属于该类的所有样本。
    在这里插入图片描述
    对角线:预测正确的样本个数。
    在这里插入图片描述
    预测值在对角线上分布的越密集,模型的性能就越好。
    还能通过混淆矩阵看到这个网络对哪些类别更容易分类出错。

  2. 混淆矩阵的指标
    在这里插入图片描述
    精确率precision不等于准确率accuracy!!
    准确率:所有预测正确的样本个数 / 所有用于验证的样本个数
    (对角线上所有数据之和 / 混淆矩阵所有数据之和)

  3. 二分类简单示例
    在这里插入图片描述
    每一列:预测值标签;
    每一列:真实值标签。

TP、TN 都代表网络预测正确的部分。(越大越好)
FP、FN 都代表网络预测错误的部分。(越小越好)

  1. 准确率、精确率、灵敏度/召回率、特异度
    在这里插入图片描述
    准确率:对所有类别的统计
    精确率、灵敏度/召回率、特异度:针对某个类别
  2. 实例
    在这里插入图片描述
    以猫为例,可以把狗和猪的类别混在一起,统一整合为不为猫的情况。得到混淆矩阵:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
  3. 参考博文
    https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839

计算混淆矩阵与相关指标

这里使用numpy进行统一计算(可以同时在TensorFlow和pytorch中使用):
在这里插入图片描述

  1. 制定一个类:ConfusionMatrix
    若图像显示不正常,就升级matplotlib
    prettytable:将输出展示成列表的形式
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

  2. 初始化函数:init
    在这里插入图片描述
    1:传入了两个变量:分类网络的分类类别个数num_classes、分类标签列表:labels。
    2:初始化一个行数和列数相等且均为num_classes的正方形的、值为零的矩阵。
    3:将num_classes赋值给类变量num_classes。

  3. 第一个实现的类:update
    在这里插入图片描述
    1:预测值preds、真实标签labels
    2:累加到混淆矩阵中。将预测、真实标签打包组合,进行遍历。
    p:预测值;t:真实类别标签
    3:矩阵[预测值(行),真实值(列)],[第t行,第p列]

  4. 第一个实现的类:summary
    统计计算各个指标
    ①准确率:
    在这里插入图片描述
    遍历,0~num_classes-1,
    统计对角线上的元素和,
    计算acc值
    ②:计算每个类别的精确率、召回率、特异度
    在这里插入图片描述
    (库)prettytable:将输出展示成列表的形式
    初始化一张表table,
    在第一行添加一些描述信息,
    使用for i in range遍历每一个类别,
    对于第i个类别,TP(true positive):对角线上元素m[i],
    FP(false positive):这一行的所有元素之和(第i行)-TP,
    FN(false negative):这一列的和(第i列)-TP
    在这里插入图片描述
    round:小数部分只取三位。
    在这里插入图片描述
    将当前类别信息添加到刚刚初始化的table里面在这里插入图片描述
    类别标签,precision,recall,specificity

  5. 绘制混淆矩阵:plot
    在这里插入图片描述
    1 将matrix赋值给matrix
    2 打印混淆矩阵
    3 使用imshow函数展示混淆矩阵。颜色变换:从白色到蓝色。
    在这里插入图片描述
    4 对于label:默认是0、1、2、3这种坐标。但是希望它展示的是标签的类别。
    使用xticks,将原来x轴的信息(0~num_classes-1)替换成为labels,对x轴旋转45°
    5 y轴同理
    6 混淆矩阵右侧像色谱一样的colorbar。数值的密集程度,颜色越深,数值就越密集。
    在这里插入图片描述
    7 横坐标 True labels
    8 纵坐标 predicted labels
    9 图像标题 Confusion matrix

将每个区域的数值标注在图像上
在这里插入图片描述
1 设置阈值,指定数字文本的颜色。取matrix最大数值的一半
2 遍历x坐标(显示图像的时候,坐标原点一般在图像的左上角),x从左到右,y从上到下。
3 遍历y坐标
4 对每一个坐标,获取它的matrix信息**[y,x]!!!不是[x,y]!!!**。 取整,得到当前位置的统计个数
在这里插入图片描述
5 通过text方法,将info绘制在[x,y]坐标处。
6、7 绘制在水平方向、竖直方向的中心位置处。
8 color对应数字的颜色,大于阈值:白色
9 让图形显示更加紧凑,否则部分信息可能被遮挡
10 展示混淆矩阵

使用pytorch计算分类模型的混淆矩阵

在这里插入图片描述
1 判断设备,是否使用GPU
2 打印设备信息
3 之前训练目标net网络时,使用针对验证集的一个处理方式,直接使用了当时已经处理好的模型权重,所以此处要与它保持相同的预处理方式。
5 使用花分类数据集的验证集
6 dataloader载入验证集
在这里插入图片描述
1 实例化网络MobileNetV2(之前写过的,这里拿来继续用)()
在这里插入图片描述
2 、3 载入之前已经训练好的MobileNetV2的模型权重
4 将模型连到设备上去
在这里插入图片描述
在这里插入图片描述
1、2 载入之前生成的json文件(对应着索引与类别信息)。读入后提取出所有的标签信息。载入后是字典形式,而我们只需要它的标签
在这里插入图片描述
4 label for_,label in class_indect.items():不要key,只要value
5 实例化上面定义的ConfusionMatrix类
6 启动验证模式
7 上下文管理器 torch.no_grad(),停止pytorch对变量梯度的跟踪
8 遍历dataloader数据集
9 分为图片,标签
10 把图片储存到设备中,输入网络,得到输出
11 softmax处理
12 通过argmax得到最大的元素
13 调用.update()方法输入预测值(outputs.numpy())、真实标签值(val_labels.numpy())
14 plot 绘制混淆矩阵
15 打印各个指标信息
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/AMWICD/article/details/129443938