python画混淆矩阵(confusion matrix)

匿名 (未验证) 提交于 2019-12-02 22:51:30

混淆矩阵(Confusion Matrix),是一种在深度学习中常用的辅助工具,可以让你直观地了解你的模型在哪一类样本里面表现得不是很好。

如上图,我们就可以看到,有一个样本原本是0的,却被预测成了1,还有一个,原本是2的,却被预测成了0。

简单介绍作用后,下面上代码:

import seaborn as sns from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt

导入需要的包,如果有一些包没有,pip一下就可以了。

sns.set() f,ax=plt.subplots() y_true = [0,0,1,2,1,2,0,2,2,0,1,1] y_pred = [1,0,1,2,1,0,0,2,2,0,1,1] C2= confusion_matrix(y_true, y_pred, labels=[0, 1, 2]) print(C2) #打印出来看看 sns.heatmap(C2,annot=True,ax=ax) #画热力图  ax.set_title('confusion matrix') #标题 ax.set_xlabel('predict') #x轴 ax.set_ylabel('true') #y轴

下面就是结果:

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!