Seaborn Confusion Matrix (heatmap) with 2 color schemes (correct diagonal vs. wrong rest)

盖世英雄少女心 2021-01-07 10:37

Background

In a confusion matrix, the diagonal represents the cases where the predicted label matches the correct label. So the diagonal is good, while all other cells…
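To make the "diagonal = correct, everything else = wrong" split concrete, here is a small NumPy sketch using the 3x3 example matrix from the answers below. It assumes the common convention (used, for example, by scikit-learn's `confusion_matrix`) that rows are true labels and columns are predicted labels; the diagonal means "correct" under either convention.

```python
import numpy as np

# Rows = true labels, columns = predicted labels (assumed convention)
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

correct = np.trace(cf_matrix)           # sum of the diagonal: predicted == true
total = cf_matrix.sum()
wrong = total - correct                 # everything off the diagonal
accuracy = correct / total

per_class_correct = np.diag(cf_matrix)  # one "correct" count per class
# Boolean mask that is True off the diagonal selects the misclassified counts
off_diag = cf_matrix[~np.eye(*cf_matrix.shape, dtype=bool)]

print(f"correct={correct}, wrong={wrong}, accuracy={accuracy:.3f}")
```

This is why a single colormap can mislead: a high value is good on the diagonal but bad everywhere else.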

2 Answers
  • 2021-01-07 10:40

    You could first plot the whole matrix as a heatmap with colormap 'OrRd' and then overlay a second heatmap with colormap 'Blues' in which the upper- and lower-triangle values are replaced with NaNs. Seaborn leaves NaN cells blank, so only the diagonal of the second layer is drawn. For example:

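    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt

    def diagonal_heatmap(m):
        # Share one color scale between both layers so the shades are comparable
        vmin = np.min(m)
        vmax = np.max(m)

        # Layer 1: the whole matrix in orange-red
        sns.heatmap(m, annot=True, cmap='OrRd', vmin=vmin, vmax=vmax)

        # Layer 2: a copy that is NaN everywhere except the diagonal;
        # seaborn skips NaN cells, so only the diagonal is painted blue
        diag_nan = np.full_like(m, np.nan, dtype=float)
        np.fill_diagonal(diag_nan, np.diag(m))

        # The second colorbar gets no ticks; the first one already shows the scale
        sns.heatmap(diag_nan, annot=True, cmap='Blues', vmin=vmin, vmax=vmax, cbar_kws={'ticks': []})

    cf_matrix = np.array([[50, 2, 38],
                          [7, 43, 32],
                          [9,  4, 76]])

    diagonal_heatmap(cf_matrix)
    plt.show()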
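    One possible refinement, sketched under the assumption that you want each cell annotated only once: in `diagonal_heatmap` the 'OrRd' layer also annotates the diagonal, so those numbers are drawn by both layers, which can look smudged if the two layers pick different text colors. Building two complementary NaN-filled copies avoids the overlap. `split_diagonal` is a hypothetical helper name, not part of NumPy or seaborn.

```python
import numpy as np


def split_diagonal(m):
    """Hypothetical helper: return (diag_only, off_diag_only) float copies of m.

    Each copy keeps one group of cells and sets the other group to NaN.
    """
    m = np.asarray(m, dtype=float)
    is_diag = np.eye(*m.shape, dtype=bool)
    diag_only = np.where(is_diag, m, np.nan)      # NaN off the diagonal
    off_diag_only = np.where(is_diag, np.nan, m)  # NaN on the diagonal
    return diag_only, off_diag_only


cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])
diag_only, off_diag_only = split_diagonal(cf_matrix)

# Every cell is non-NaN in exactly one of the two layers
assert np.all(np.isnan(diag_only) != np.isnan(off_diag_only))
```

    Passing `off_diag_only` to the 'OrRd' call and `diag_only` to the 'Blues' call (with the same `vmin`/`vmax`) should then draw every cell, and its annotation, exactly once, since seaborn masks cells with missing values.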
    
  • 2021-01-07 11:00

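    You can use the mask= parameter of heatmap() to choose which cells are shown: cells where the mask is True are left out. Drawing the matrix twice, once with a mask that hides the off-diagonal cells and once with a mask that hides the diagonal, gives each group its own colormap. Passing the same vmin and vmax to both calls keeps the two color scales aligned: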

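    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt

    cf_matrix = np.array([[50, 2, 38],
                          [7, 43, 32],
                          [9,  4, 76]])

    # One shared color scale for both layers
    vmin = np.min(cf_matrix)
    vmax = np.max(cf_matrix)
    # True on the diagonal, i.e. the mask to use for the off-diagonal layer
    off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)

    fig = plt.figure()
    # ~off_diag_mask hides the off-diagonal cells: only the diagonal is drawn (blue)
    sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax)
    # off_diag_mask hides the diagonal: only the off-diagonal cells are drawn (orange-red)
    sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, cbar_kws=dict(ticks=[]))
    plt.show()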
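    To check which cells each layer keeps without plotting anything, the mask logic can be mirrored with NumPy masked arrays, assuming seaborn's documented rule that cells where `mask` is True are not shown. `blues_layer` and `orrd_layer` are illustrative names only.

```python
import numpy as np

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)  # True on the diagonal

# Mirror "cells where mask is True are not shown" with masked arrays
blues_layer = np.ma.masked_array(cf_matrix, mask=~off_diag_mask)  # diagonal only
orrd_layer = np.ma.masked_array(cf_matrix, mask=off_diag_mask)    # off-diagonal only

# compressed() returns the visible (unmasked) values in row-major order
print("Blues shows:", blues_layer.compressed())
print("OrRd shows: ", orrd_layer.compressed())
```

    Because the two masks are exact complements, every cell belongs to exactly one layer, which is what lets the two colormaps share the same axes cleanly.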
    

    If you want to get fancy, you can create the axes with GridSpec for a better layout, with the two colorbars placed side by side next to the heatmap:

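    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    # cf_matrix, vmin, vmax and off_diag_mask as defined in the previous snippet

    fig = plt.figure()
    # Wide column for the heatmap, narrow column for the colorbars
    # (in a single-row grid, wspace rather than hspace sets the horizontal gap)
    gs0 = GridSpec(1, 2, width_ratios=[20, 2], wspace=0.05)
    # Split the narrow column into two colorbar axes with no gap between them
    gs00 = GridSpecFromSubplotSpec(1, 2, subplot_spec=gs0[1], wspace=0)

    ax = fig.add_subplot(gs0[0])
    cax1 = fig.add_subplot(gs00[0])
    cax2 = fig.add_subplot(gs00[1])

    sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax2)
    sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax1, cbar_kws=dict(ticks=[]))
    plt.show()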
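    The layout proportions can be sanity-checked with a little arithmetic. This sketch assumes GridSpec's documented meaning of `width_ratios` (relative column widths) and `wspace` (gap as a fraction of the average column width), and works in fractions of the grid's total width. `column_widths` is a hypothetical helper, not a matplotlib function.

```python
def column_widths(width_ratios, wspace, total_width=1.0):
    """Hypothetical helper: split total_width into columns following width_ratios.

    Assumes wspace is a fraction of the *average* column width.
    Returns (list of column widths, width of one gap between columns).
    """
    ncols = len(width_ratios)
    # total = ncols * avg + (ncols - 1) * wspace * avg  ->  solve for avg
    avg = total_width / (ncols + wspace * (ncols - 1))
    gap = wspace * avg
    # Scale the ratios so the columns keep their proportions and average to avg
    scale = avg * ncols / sum(width_ratios)
    return [r * scale for r in width_ratios], gap


widths, gap = column_widths([20, 2], wspace=0.05)
heatmap_w, colorbar_col_w = widths
# The colorbar column is split again into two equal colorbars (wspace=0)
single_colorbar_w = colorbar_col_w / 2
print(f"heatmap {heatmap_w:.3f}, each colorbar {single_colorbar_w:.3f}, gap {gap:.3f}")
```

    For example, with `width_ratios=[20, 2]` and a gap of `wspace=0.05`, the heatmap gets roughly 89% of the grid width and each of the two colorbars roughly 4.4%.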
    
