混淆矩阵M的每一行代表每个真实类(GT),每一列表示预测的类。即:Mij表示GroundTruth类别为i的所有数据中被预测为类别j的数目。
这里采用画图像的办法,绘制混淆矩阵的表示图。颜色越深,值越大。
# -*- coding: utf-8 -*- # By Changxu Cheng, HUST from __future__ import division import numpy as np from skimage import io, color from PIL import Image, ImageDraw, ImageFont import os def drawCM(matrix, savname): # Display different color for different elements lines, cols = matrix.shape sumline = matrix.sum(axis=1).reshape(lines, 1) ratiomat = matrix / sumline toplot0 = 1 - ratiomat toplot = toplot0.repeat(50).reshape(lines, -1).repeat(50, axis=0) io.imsave(savname, color.gray2rgb(toplot)) # Draw values on every block image = Image.open(savname) draw = ImageDraw.Draw(image) font = ImageFont.truetype(os.path.join(os.getcwd(), "draw/ARIAL.TTF"), 15) for i in range(lines): for j in range(cols): dig = str(matrix[i, j]) if i == j: filled = (255, 181, 197) else: filled = (46, 139, 87) draw.text((50 * j + 10, 50 * i + 10), dig, font=font, fill=filled) image.save(savname, 'jpeg') if __name__ == "__main__": drawCM(np.random.randint(16, size=16).reshape(4,4), 'tmp.jpg')
注意:需要用到字体文件。代码中使用的是ARIAL.TTF。这样才可以在图中直接标注出数目。
某实验结果图如下(不是上述__name == "__main__"代码的执行结果)
文章来源: python画混淆矩阵