Question
Background
In a confusion matrix, the diagonal holds the cases where the predicted label matches the correct label. So the diagonal is good, while all other cells are bad. To make it clear to non-experts what is good and what is bad in a CM, I want to give the diagonal a different color than the rest. I want to achieve this with Python and Seaborn.
Basically I'm trying to achieve what this question does in R (ggplot2 Heatmap 2 Different Color Schemes - Confusion Matrix: Matches in Different Color Scheme than Missclassifications)
Normal Seaborn Confusion Matrix with heatmap
import numpy as np
import seaborn as sns
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9, 4, 76]])
sns.heatmap(cf_matrix, annot=True, cmap='Blues') # cmap='OrRd'
Which results in this image:
Goal
I would like to color the non-diagonal cells with e.g. cmap='OrRd'. So I imagine there would be 2 colorbars, 1 blue for the diagonal and 1 for the other cells. Preferably the value ranges of both colorbars match (so both e.g. 0-70, and not 0-70 and 0-40).
How would I approach this?
The following is not made with code, but with photo editing software:
Answer 1:
You can use mask= in the call to heatmap() to choose which cells to show. Using two different masks, one for the diagonal and one for the off-diagonal cells, you can get the desired output:
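import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9, 4, 76]])
# Share one value range so both colorbars use the same scale.
vmin = np.min(cf_matrix)
vmax = np.max(cf_matrix)
# True on the diagonal. heatmap() hides cells where mask is True, so passing
# this mask leaves only the off-diagonal cells visible.
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)
fig = plt.figure()
# Diagonal (correct predictions) in blue.
sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax)
# Off-diagonal (misclassifications) in orange-red; this colorbar gets no ticks.
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, cbar_kws=dict(ticks=[]))
plt.show()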
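If it is unclear which mask belongs to which call, the selection can be checked without plotting anything. This is only a small sketch using NumPy boolean indexing; it relies on the documented heatmap() behavior that cells where mask is True are not drawn. The names blue_cells and red_cells are made up for illustration.

```python
import numpy as np

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9, 4, 76]])

# True on the diagonal, False elsewhere.
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)

# mask=~off_diag_mask hides the off-diagonal cells, so these stay visible:
blue_cells = cf_matrix[off_diag_mask]
# mask=off_diag_mask hides the diagonal, so these stay visible:
red_cells = cf_matrix[~off_diag_mask]

print(blue_cells.tolist())  # [50, 43, 76]
print(red_cells.tolist())   # [2, 38, 7, 32, 9, 4]
```

Every cell ends up in exactly one of the two groups, so the two heatmaps never paint over each other.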
If you want to get fancy, you can create the axes using GridSpec to have a better layout:
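import matplotlib.gridspec
import matplotlib.pyplot as plt
import seaborn as sns
# Reuses cf_matrix, vmin, vmax and off_diag_mask from the previous snippet.
fig = plt.figure()
# Left: the heatmap. Right: a narrow strip that holds both colorbars side by side.
# In a single-row grid the gap between columns is set by wspace, not hspace.
gs0 = matplotlib.gridspec.GridSpec(1, 2, width_ratios=[20, 2], wspace=0.05)
gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs0[1], wspace=0)
ax = fig.add_subplot(gs0[0])
cax1 = fig.add_subplot(gs00[0])
cax2 = fig.add_subplot(gs00[1])
sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax2)
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax1, cbar_kws=dict(ticks=[]))
plt.show()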
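Since the plot is meant for non-experts, it may help to wrap the two-mask approach in a helper that also labels the axes. The following is just a sketch: the function name plot_two_tone_cm, the labels argument, the class names in the usage line, and the axis titles are all my own assumptions, and it assumes rows are true labels and columns are predicted labels (the scikit-learn convention), which the question does not state.

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def plot_two_tone_cm(cm, labels=None, ax=None):
    """Draw cm with the diagonal in 'Blues' and all other cells in 'OrRd'.

    Hypothetical helper, not part of seaborn.
    """
    cm = np.asarray(cm)
    if ax is None:
        _, ax = plt.subplots()
    tick_labels = labels if labels is not None else "auto"
    hide_diag = np.eye(*cm.shape, dtype=bool)  # True = cell is not drawn
    # Same value range for both layers so the two colorbars are comparable.
    shared = dict(annot=True, vmin=cm.min(), vmax=cm.max(), ax=ax,
                  xticklabels=tick_labels, yticklabels=tick_labels)
    sns.heatmap(cm, mask=~hide_diag, cmap="Blues", **shared)
    sns.heatmap(cm, mask=hide_diag, cmap="OrRd",
                cbar_kws=dict(ticks=[]), **shared)
    # Assumed convention: rows = true labels, columns = predicted labels.
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return ax


ax = plot_two_tone_cm([[50, 2, 38],
                       [7, 43, 32],
                       [9, 4, 76]],
                      labels=["a", "b", "c"])
```

Call plt.show() or ax.figure.savefig(...) afterwards to see the result. If your matrix is laid out the other way round, swap the two axis titles.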
Answer 2:
You could first plot the full heatmap with the 'OrRd' colormap and then overlay a second heatmap with the 'Blues' colormap, in which the upper- and lower-triangle values are replaced with NaN. See the following example:
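import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def diagonal_heatmap(m):
    vmin = np.min(m)
    vmax = np.max(m)
    # Base layer: every cell in orange-red (use the argument m, not a global).
    sns.heatmap(m, annot=True, cmap='OrRd', vmin=vmin, vmax=vmax)
    # Overlay: keep only the diagonal values; heatmap() masks NaN cells automatically.
    diag_nan = np.full_like(m, np.nan, dtype=float)
    np.fill_diagonal(diag_nan, np.diag(m))
    sns.heatmap(diag_nan, annot=True, cmap='Blues', vmin=vmin, vmax=vmax, cbar_kws={'ticks':[]})
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9, 4, 76]])
diagonal_heatmap(cf_matrix)
plt.show()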
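The NaN overlay works because seaborn automatically masks cells with missing values, so the NaN pattern plays the same role as an explicit mask. A quick way to convince yourself, sketched with NumPy only (same_as_mask is just an illustrative name):

```python
import numpy as np

m = np.array([[50, 2, 38],
              [7, 43, 32],
              [9, 4, 76]])

# dtype=float is required: an integer array cannot hold NaN.
diag_nan = np.full_like(m, np.nan, dtype=float)
np.fill_diagonal(diag_nan, np.diag(m))

# The NaN cells are exactly the off-diagonal cells, i.e. the same cells
# that an explicit mask of ~np.eye(...) would hide.
same_as_mask = np.array_equal(np.isnan(diag_nan),
                              ~np.eye(*m.shape, dtype=bool))
print(same_as_mask)  # True
```

One difference from the explicit-mask approach is that the base 'OrRd' layer here also draws and annotates the diagonal, which the overlay then covers.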
Source: https://stackoverflow.com/questions/64800003/seaborn-confusion-matrix-heatmap-2-color-schemes-correct-diagonal-vs-wrong-re