Annotated heatmap with multiple color schemes

Submitted by 时光怂恿深爱的人放手 on 2021-02-07 04:18:46


I have the following dataframe and would like to highlight the small decimal differences within each "step" in a heatmap by giving every step its own color scheme.

Sample data:

Sample  Step 2  Step 3  Step 4  Step 5  Step 6  Step 7  Step 8
A   64.847  54.821  20.897  39.733  23.257  74.942  75.945
B   64.885  54.767  20.828  39.613  23.093  74.963  75.928
C   65.036  54.772  20.939  39.835  23.283  74.944  75.871
D   64.869  54.740  21.039  39.889  23.322  74.925  75.894
E   64.911  54.730  20.858  39.608  23.101  74.956  75.930
F   64.838  54.749  20.707  39.394  22.984  74.929  75.941
G   64.887  54.781  20.948  39.748  23.238  74.957  75.909
H   64.903  54.720  20.783  39.540  23.028  74.898  75.911
I   64.875  54.761  20.911  39.695  23.082  74.897  75.866
J   64.839  54.717  20.692  39.377  22.853  74.849  75.939
K   64.857  54.736  20.934  39.699  23.130  74.880  75.903
L   64.754  54.746  20.777  39.536  22.991  74.877  75.902
M   64.798  54.811  20.963  39.824  23.187  74.886  75.895
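To run the snippets in the answers without retyping the table, the sample data can be built as a DataFrame directly. This is a minimal sketch: the values are copied from the table above, and the name `df` is assumed because that is what the answers use.

```python
import pandas as pd

# Values copied from the sample table: one row per sample, one column per step
values = [
    [64.847, 54.821, 20.897, 39.733, 23.257, 74.942, 75.945],
    [64.885, 54.767, 20.828, 39.613, 23.093, 74.963, 75.928],
    [65.036, 54.772, 20.939, 39.835, 23.283, 74.944, 75.871],
    [64.869, 54.740, 21.039, 39.889, 23.322, 74.925, 75.894],
    [64.911, 54.730, 20.858, 39.608, 23.101, 74.956, 75.930],
    [64.838, 54.749, 20.707, 39.394, 22.984, 74.929, 75.941],
    [64.887, 54.781, 20.948, 39.748, 23.238, 74.957, 75.909],
    [64.903, 54.720, 20.783, 39.540, 23.028, 74.898, 75.911],
    [64.875, 54.761, 20.911, 39.695, 23.082, 74.897, 75.866],
    [64.839, 54.717, 20.692, 39.377, 22.853, 74.849, 75.939],
    [64.857, 54.736, 20.934, 39.699, 23.130, 74.880, 75.903],
    [64.754, 54.746, 20.777, 39.536, 22.991, 74.877, 75.902],
    [64.798, 54.811, 20.963, 39.824, 23.187, 74.886, 75.895],
]
df = pd.DataFrame(
    values,
    index=pd.Index(list("ABCDEFGHIJKLM"), name="Sample"),
    columns=[f"Step {i}" for i in range(2, 9)],  # "Step 2" ... "Step 8"
)
```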

An example of what I am looking for:


My first approach would be based on a figure with multiple subplots. The number of plots would equal the number of columns in your dataframe, and the gap between the plots can be shrunk down to zero:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

cm = ['Blues', 'Reds', 'Greens', 'Oranges', 'Purples', 'bone', 'winter']
f, axs = plt.subplots(1, df.columns.size, gridspec_kw={'wspace': 0})
for i, (s, a, c) in enumerate(zip(df.columns, axs, cm)):
    sns.heatmap(np.array([df[s].values]).T, yticklabels=df.index, xticklabels=[s], annot=True, fmt='.2f', ax=a, cmap=c, cbar=False)
    if i>0:
        a.yaxis.set_ticks([])  # only the first subplot keeps the sample labels


I'm not sure this will lead to a helpful, or even self-describing, visualization of the data, but that's your choice; perhaps this helps you get started.


Regarding adding the colorbars: of course you can. But, besides not knowing the background of your data and the purpose of the visualization, I'd like to add some thoughts on all that:

First: adding all those colorbars as a separate group of bars beside or below the heatmap is probably possible, but I already find the data quite hard to read, and you also have all those annotations; I think it would clutter everything.
Additionally, in the meantime @ImportanceOfBeingErnest has provided such a beautiful solution on that topic that repeating it here would not add much, in my opinion.

Second: if you really want to stick with the heatmap, perhaps splitting it up and giving every column its own colorbar would suit better:

cm = ['Blues', 'Reds', 'Greens', 'Oranges', 'Purples', 'bone', 'winter']
f, axs = plt.subplots(1, df.columns.size, figsize=(10, 3))
for i, (s, a, c) in enumerate(zip(df.columns, axs, cm)):
    # cbar is left at its default (True), so every column gets its own colorbar
    sns.heatmap(np.array([df[s].values]).T, yticklabels=df.index, xticklabels=[s], annot=True, fmt='.2f', ax=a, cmap=c)
    if i>0:
        a.yaxis.set_ticks([])  # only the first subplot keeps the sample labels

However, all that said, I doubt that this is the best visualization for your data. Of course, I don't know what you want to say, see or find with these plots, but that's the point: if the visualization type fit the need, I guess I would know (or at least could imagine).

Just for example:
A simple df.plot() draws one line per column,

and I feel that this tells more about the different characteristics of your columns within a few tenths of a second than the heatmap does.
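A runnable version of that line plot, as a sketch: it assumes a non-interactive backend (so it also works without a display) and uses only rows A to D of the sample data to keep it short. `DataFrame.plot()` defaults to a line plot with one line per column and the index along the x axis.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for an interactive window
import matplotlib.pyplot as plt
import pandas as pd

# Rows A-D of the sample data (subset chosen only to keep the example short)
df = pd.DataFrame(
    [[64.847, 54.821, 20.897, 39.733, 23.257, 74.942, 75.945],
     [64.885, 54.767, 20.828, 39.613, 23.093, 74.963, 75.928],
     [65.036, 54.772, 20.939, 39.835, 23.283, 74.944, 75.871],
     [64.869, 54.740, 21.039, 39.889, 23.322, 74.925, 75.894]],
    index=list("ABCD"),
    columns=[f"Step {i}" for i in range(2, 9)],
)

# kind="line" is the default: one Line2D per column, samples along the x axis
ax = df.plot()
ax.set_xlabel("Sample")
ax.set_ylabel("Value")
plt.tight_layout()
```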

Or are you explicitly after each column's deviations from its mean?

(df - df.mean()).plot()

... or the distribution of each column around them?

(df - df.mean()).boxplot()
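Both one-liners rely on the same mechanism: DataFrame.mean() returns one mean per column (a Series indexed by column label), and subtracting that Series from the DataFrame aligns it with the columns, so every column ends up centered on zero. A small sketch with made-up toy numbers (not the question's data):

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: two columns on very different scales
df = pd.DataFrame({"Step 2": [64.8, 64.9, 65.0], "Step 3": [20.7, 20.9, 21.1]},
                  index=list("ABC"))

# df.mean() has one entry per column; the subtraction aligns on column labels
deviations = df - df.mean()
print(deviations.round(3))
```

Because each column is now centered, the line plot and the boxplot of `deviations` show all steps on one shared scale, even though the raw values range from about 20 to 76.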

What I want to say: data visualization becomes powerful when a plot begins to tell something about the underlying data before you have to explain anything.


I suppose the problem can be divided into several parts.

Getting several heatmaps with different colormaps into the same picture. This can be done by masking the complete array column-wise, plotting each masked array separately via imshow, and applying a different colormap to each. To visualize the concept:
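The masking step can be checked on a tiny array without plotting anything. This sketch reuses the np.tile(...).reshape(...) trick from the full code, on a made-up 3×4 array: for column i, every entry outside that column is masked, so imshow would draw only that column and leave the rest transparent.

```python
import numpy as np

array = np.arange(12, dtype=float).reshape(3, 4)  # made-up 3x4 data

# premask[r, c] == c: every entry holds its own column index
premask = np.tile(np.arange(array.shape[1]), array.shape[0]).reshape(array.shape)

# One masked array per column: only column i stays unmasked
layers = [np.ma.array(array, mask=premask != i) for i in range(array.shape[1])]

for i, layer in enumerate(layers):
    print(i, layer.compressed())  # the values that would be drawn for column i
```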

Obtaining a variable number of distinct colormaps. Matplotlib provides a large number of colormaps; however, they generally differ a lot in luminosity and saturation. Here it seems desirable to have colormaps of differing hue but otherwise the same saturation and luminosity.
One option is to create the colormaps on the fly: choose n different (equally spaced) hues and build a colormap for each using the same saturation and luminosity.
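A sketch of such on-the-fly colormaps, with one assumption that differs from the full code: instead of borrowing saturation and value from an existing colormap, it uses a fixed linear ramp (the 0.15 to 1.0 saturation and 1.0 to 0.6 value ranges are arbitrary choices). The helper name hue_cmaps is hypothetical.

```python
import numpy as np
import matplotlib.colors as mcolors

def hue_cmaps(n, rot=0.0, steps=24):
    """Hypothetical helper: n colormaps that differ only in hue."""
    # Shared ramp for every colormap: pale/light -> saturated/darker (assumed values)
    sat = np.linspace(0.15, 1.0, steps)
    val = np.linspace(1.0, 0.6, steps)
    cmaps = []
    for i in range(n):
        hue = (i / n + rot) % 1.0  # equally spaced hues around the color circle
        hsv = np.column_stack([np.full(steps, hue), sat, val])
        rgb = mcolors.hsv_to_rgb(hsv)
        cmaps.append(mcolors.LinearSegmentedColormap.from_list(f"hue{i}", rgb))
    return cmaps

cmaps = hue_cmaps(7)
```

Keeping the saturation/value ramp identical across all colormaps means that equally "dark" cells carry the same visual weight in every column; only the hue tells the columns apart.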

Obtaining a distinct colorbar for each column. Since the values within columns might be on totally different scales, a colorbar for each column is needed to know the values shown; e.g. in the first column the brightest color may correspond to a value of 1, while in the second column it may correspond to a value of 100. Several colorbars can be created inside the axes of a GridSpec placed next to the actual heatmap axes. The number of rows and columns of that GridSpec would depend on the number of columns in the dataframe.
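The bookkeeping for that GridSpec is small: with a fixed number of rows, the number of grid columns is a ceiling division, and colorbar i goes into row i // cols, column i % cols. A sketch of just that arithmetic (the helper name colorbar_slots is hypothetical), using the same formula as the full code:

```python
import math

def colorbar_slots(n_columns, rows=3):
    """Hypothetical helper: (row, col) cell of each colorbar in a rows x cols grid."""
    cols = math.ceil(n_columns / rows)  # same as n // rows + int(n % rows > 0)
    return cols, [(i // cols, i % cols) for i in range(n_columns)]

cols, slots = colorbar_slots(8)
print(cols, slots)
```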

In total this may then look as follows.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec

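def get_hsvcmap(i, N, rot=0.):
    nsc = 24
    # hue of the i-th of N equally spaced hues (shifted by rot), taken from the hsv colormap
    chsv = mcolors.rgb_to_hsv(plt.cm.hsv(((np.arange(N)/N)+rot) % 1.)[i,:3])
    # saturation/value ramp borrowed from a sequential reference colormap
    rhsv = mcolors.rgb_to_hsv(plt.cm.Reds(np.linspace(.2,1,nsc))[:,:3])
    # keep the hue fixed, replace saturation and value with the reference ramp
    arhsv = np.tile(chsv,nsc).reshape(nsc,3)
    arhsv[:,1:] = rhsv[:,1:]
    rgb = mcolors.hsv_to_rgb(arhsv)
    return mcolors.LinearSegmentedColormap.from_list("",rgb)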

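def columnwise_heatmap(array, ax=None, **kw):
    ax = ax or plt.gca()
    # premask[r, c] == c, so "premask != i" masks everything except column i
    premask = np.tile(np.arange(array.shape[1]), array.shape[0]).reshape(array.shape)
    images = []
    for i in range(array.shape[1]):
        col = np.ma.array(array, mask = premask != i)
        im = ax.imshow(col, cmap=get_hsvcmap(i, array.shape[1], rot=0.5), **kw)
        images.append(im)  # keep each image so a colorbar can be attached later
    return images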

### Create some dataset
ind = list("ABCDEFGHIJKLM")
m = len(ind)
n = 8
df = pd.DataFrame(np.random.randn(m,n) + np.random.randint(20,70,n), 
                  index=ind, columns=[f"Step {i}" for i in range(2,2+n)])

### Plot data
fig, ax = plt.subplots(figsize=(8,4.5))

ims = columnwise_heatmap(df.values, ax=ax, aspect="auto")

ax.set(xticks=np.arange(len(df.columns)), yticks=np.arange(len(df)),
       xticklabels=df.columns, yticklabels=df.index)
ax.tick_params(bottom=False, top=False, 
               labelbottom=False, labeltop=True, left=False)

### Optionally add colorbars.
fig.subplots_adjust(left=0.06, right=0.65)
rows = 3
cols = len(df.columns) // rows + int(len(df.columns)%rows > 0)
gs = GridSpec(rows, cols)
gs.update(left=0.7, right=0.95, wspace=1, hspace=0.3)
for i, im in enumerate(ims):
    cax = fig.add_subplot(gs[i//cols, i % cols])
    fig.colorbar(im, cax = cax)
    cax.set_title(df.columns[i], fontsize=10)

plt.show()

