python画混淆矩阵

匿名 (未验证) 提交于 2019-12-02 22:51:30

混淆矩阵M的每一行代表每个真实类(GT),每一列表示预测的类。即:Mij表示GroundTruth类别为i的所有数据中被预测为类别j的数目。

这里采用画图像的办法,绘制混淆矩阵的表示图。颜色越深,值越大。

# -*- coding: utf-8 -*- # By Changxu Cheng, HUST  from __future__ import division import  numpy as np from skimage import io, color from PIL import Image, ImageDraw, ImageFont import os  def drawCM(matrix, savname):     # Display different color for different elements     lines, cols = matrix.shape     sumline = matrix.sum(axis=1).reshape(lines, 1)     ratiomat = matrix / sumline     toplot0 = 1 - ratiomat     toplot = toplot0.repeat(50).reshape(lines, -1).repeat(50, axis=0)     io.imsave(savname, color.gray2rgb(toplot))     # Draw values on every block     image = Image.open(savname)     draw = ImageDraw.Draw(image)     font = ImageFont.truetype(os.path.join(os.getcwd(), "draw/ARIAL.TTF"), 15)     for i in range(lines):         for j in range(cols):             dig = str(matrix[i, j])             if i == j:                 filled = (255, 181, 197)             else:                 filled = (46, 139, 87)             draw.text((50 * j + 10, 50 * i + 10), dig, font=font, fill=filled)     image.save(savname, 'jpeg')  if __name__ == "__main__":     drawCM(np.random.randint(16, size=16).reshape(4,4), 'tmp.jpg')

注意:需要用到字体文件。代码中使用的是ARIAL.TTF。这样才可以在图中直接标注出数目。

某实验结果图如下(不是上述__name == "__main__"代码的执行结果)


文章来源: python画混淆矩阵
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!