使用keras的fit_generator来获得混淆矩阵Confusion Matrix

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xfjjs_net/article/details/84798470

还是google过来的方法,说明它还是挺靠谱滴。这里有必要记录一下。

关于混乱淆矩阵是用来干嘛的,这里就不说了。各位可以百度or谷歌。

关于如何使用fit_generator来进行训练可以看我上一篇文章。

我们在使用fit_generator方法来进行训练的时候,是不需要自己读取x_img_train,y_label_train的。都是generator帮我们做好了。

但是要想画混淆矩阵的话得要 验证集中原始图像的标签与预测到的标签值。

model.fit下面有对应的evaluate,predict 等,那么model.fit_generator下面自然也有对应的,且看:

fit_generator

evaluate_generator:

evaluate_generator(generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)

其它很多参数我们暂时用不到,可以写成(关于validation_generator可以去看我上一篇文章):

model.evaluate_generator(validation_generator,verbose=1)

predict_generator:

predict_generator(generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)

也可以直接写成:

prediction=model.predict_generator(validation_generator,verbose=1)

接下来我们对prediction做一下处理

#因为prediction是一个n行5列的数组,我们要把它转换成一维数组
#这样的话每个值就会与验证集中的标签值一一对应上了。
predict_label=np.argmax(prediction,axis=1)

验证集中真实数据的标签为:

true_label=validation_generator.classes

好了,混淆矩阵中所需要的两个参数我们都已经得到了。

1:true_label  真实数据标签

2:predict_label  预测的数据标签

接下来使用pd.crosstab来简单画出混淆矩阵

import pandas as pd
pd.crosstab(true_label,predict_label,rownames=['label'],colnames=['predict'])

下面是操作的注意事项:

我这里的 validation_generator 是没有被shuffle的。这样的话正好与后面的真实标签跟预测标签一一对应。

如果前面被shuffle的话,这边肯定就对不上了。

不知道网友们有没有什么更好的方法来把这两种标签相互映射?

文章是对着我自己的项目写的,具体数据这里没给出。如果各位看不懂的话可以留言,我把代码完整贴上......

猜你喜欢

转载自blog.csdn.net/xfjjs_net/article/details/84798470