How to express classes on the axis of a heatmap in Seaborn

走了就别回头了 2020-12-01 02:13

I created a very simple heatmap chart with Seaborn displaying a similarity square matrix. Here is the one line of code I used:

sns.heatmap(sim_mat, linewidth         


        
3 answers
  • 2020-12-01 02:53

    Building on the clustermap answer, it is worth noting that labels can carry multiple colour levels, as the clustermap docs mention for {row,col}_colors. I couldn't find an example with multiple levels, so I'm sharing one here.

    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])
    

    network level

    network_labels = networks.columns.get_level_values("network")
    network_pal = sns.cubehelix_palette(network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2)
    network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
    

    Create a Series of network colours, indexed by the data's columns

    network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)
    

    node level

    node_labels = networks.columns.get_level_values("node")
    node_pal = sns.cubehelix_palette(node_labels.unique().size)
    node_lut = dict(zip(map(str, node_labels.unique()), node_pal))
    

    Create a Series of node colours, indexed by the data's columns

    node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)
    

    Create dataframe for row and column color levels

    network_node_colors = pd.DataFrame(network_colors).join(pd.DataFrame(node_colors))
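
The lookup-and-join steps above can be tried without downloading the dataset. The following is a minimal sketch on synthetic data, not the code from this answer: the column labels, the `level_colors` helper and the palette names are all hypothetical, chosen only to illustrate how one colour column per level ends up in a single data frame.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical two-level column labels standing in for the "network" and
# "node" levels of the brain_networks dataset.
columns = pd.MultiIndex.from_tuples(
    [("A", "1"), ("A", "2"), ("B", "1"), ("B", "2"), ("C", "1")],
    names=["network", "node"],
)
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(20, len(columns))), columns=columns)


def level_colors(frame, level, palette_name):
    """Return (lookup table, per-column colour Series) for one column level."""
    labels = frame.columns.get_level_values(level)
    unique = labels.unique()
    palette = sns.color_palette(palette_name, len(unique))
    lut = dict(zip(unique, palette))
    # Index the colours by the data's own columns so they line up with it.
    colors = pd.Series(labels, index=frame.columns, name=level).map(lut)
    return lut, colors


network_lut, network_colors = level_colors(data, "network", "Set2")
node_lut, node_colors = level_colors(data, "node", "Blues")

# One column of colours per level; clustermap draws one strip per column.
color_frame = pd.concat([network_colors, node_colors], axis=1)

# A label missing from its lookup table would show up here as NaN.
assert not color_frame.isna().any().any()
```

A frame built this way can be passed as `row_colors` and `col_colors` to `sns.clustermap(data.corr(), ...)` in the same way as `network_node_colors`.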
    

    create clustermap

    g = sns.clustermap(networks.corr(),
                       # Turn off the clustering
                       row_cluster=False, col_cluster=False,
                       # Add colored class labels using the data frame created from node and network colors
                       row_colors=network_node_colors,
                       col_colors=network_node_colors,
                       # Make the plot look better when many rows/cols
                       linewidths=0,
                       xticklabels=False, yticklabels=False,
                       center=0, cmap="vlag")
    

    Create two legends, one for each level, by drawing invisible bar plots on the column and row dendrogram axes (the same trick as in the clustermap answer)

    network legend

    from matplotlib.pyplot import gcf
    
    for label in network_labels.unique():
        g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label, linewidth=0)
    
    l1 = g.ax_col_dendrogram.legend(title='Network', loc="center", ncol=5, bbox_to_anchor=(0.47, 0.8), bbox_transform=gcf().transFigure)
    

    node legend

    for label in node_labels.unique():
        g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0)
    
    l2 = g.ax_row_dendrogram.legend(title='Node', loc="center", ncol=2, bbox_to_anchor=(0.8, 0.8), bbox_transform=gcf().transFigure)
    
    plt.show()
    

  • 2020-12-01 02:55

    When both dendrograms are in use, so that neither dendrogram axes is free to host a legend, you can instead add a new hidden axes to the figure and draw the legend from it.

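    # Assumes g is the ClusterGrid returned by sns.clustermap
    f = g.ax_heatmap.get_figure()
    ax = f.add_axes((0, 0, 0, 0))  # zero-size axes, used only to hold the legend handles
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    
    for label in node_labels.unique():
        ax.bar(0, 0, color=node_lut[label], label=label, linewidth=0)
    
    # Build the legend from the hidden axes, where the bars were drawn
    l2 = ax.legend(title='Node', loc="center", ncol=2, bbox_to_anchor=(0.8, 0.8), bbox_transform=f.transFigure)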
    
  • 2020-12-01 03:04

    There are two options:

    First, heatmap is an Axes-level function, so you could set up a large main heatmap axes for the correlation matrix and flank it with narrow heatmaps that you pass the class colors to yourself. This takes a bit of work, but gives you a lot of control over how everything works.

    Second, clustermap already supports more or less the same thing through row_colors and col_colors, so I'm going to demonstrate that approach here. It's a bit of a hack, but it works.

    First, we'll load the sample data and do a bit of roundabout transformations to get colors for the class labels.

    import pandas as pd
    import seaborn as sns
    
    networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])
    network_labels = networks.columns.get_level_values("network")
    network_pal = sns.cubehelix_palette(network_labels.unique().size,
                                        light=.9, dark=.1, reverse=True,
                                        start=1, rot=-2)
    network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
    
    # Index the colors by the data's columns so clustermap can align them
    network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)
    

    Next we call clustermap to make the main plot.

    g = sns.clustermap(networks.corr(),
    
                      # Turn off the clustering
                      row_cluster=False, col_cluster=False,
    
                      # Add colored class labels
                      row_colors=network_colors, col_colors=network_colors,
    
                      # Make the plot look better when many rows/cols
                      linewidths=0, xticklabels=False, yticklabels=False)
    

    The side colors are drawn with a heatmap, which matplotlib thinks of as quantitative data and thus there's not a straightforward way to get a legend directly from it. Instead of that, we'll add an invisible barplot with the right colors and labels, then add a legend for that.

    for label in network_labels.unique():
        g.ax_col_dendrogram.bar(0, 0, color=network_lut[label],
                                label=label, linewidth=0)
    g.ax_col_dendrogram.legend(loc="center", ncol=6)
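
As an alternative to the invisible bars, matplotlib legends can also be built from explicit proxy handles. This is a minimal sketch of that variant rather than what the answer above does; `class_lut` and its colours are hypothetical stand-ins for `network_lut`, and a plain `plt.subplots()` axes stands in for `g.ax_col_dendrogram`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Hypothetical class-to-colour lookup, standing in for network_lut.
class_lut = {"A": "tab:blue", "B": "tab:orange", "C": "tab:green"}

fig, ax = plt.subplots()

# One proxy Patch per class; nothing is drawn on the axes itself.
handles = [Patch(facecolor=color, label=name) for name, color in class_lut.items()]
legend = ax.legend(handles=handles, title="Class", loc="center", ncol=3)
```

With a clustermap, the same `handles` list could be passed to `g.ax_col_dendrogram.legend(handles=handles, ...)`, which avoids adding zero-height bars to that axes.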
    

    Finally, let's move the colorbar to take up the empty space where the row dendrogram would normally be and save the figure.

    g.cax.set_position([.15, .2, .03, .45])
    g.savefig("clustermap.png")
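
For completeness, the first option (a main heatmap flanked by narrow class-colour heatmaps) might look roughly like the sketch below. It uses synthetic data instead of the brain_networks dataset, and the class labels, grid ratios and palette are assumptions made for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap

# Synthetic stand-in data: 9 variables in 3 hypothetical classes.
rng = np.random.default_rng(0)
classes = pd.Series(list("AAABBBCCC"), name="class")
data = pd.DataFrame(rng.normal(size=(50, len(classes))))
corr = data.corr()

# Encode each class as an integer code and give each code one colour.
codes, uniques = pd.factorize(classes)
cmap = ListedColormap(sns.color_palette("Set2", len(uniques)))

# 2x2 grid: narrow strips along the top and left, main heatmap bottom right.
fig, axes = plt.subplots(
    2, 2, figsize=(6, 6),
    gridspec_kw={"width_ratios": [1, 20], "height_ratios": [1, 20],
                 "wspace": 0.02, "hspace": 0.02},
)
axes[0, 0].axis("off")  # unused corner

strip_kws = dict(cmap=cmap, vmin=0, vmax=len(uniques) - 1, cbar=False,
                 xticklabels=False, yticklabels=False)
sns.heatmap(codes.reshape(1, -1), ax=axes[0, 1], **strip_kws)  # column classes
sns.heatmap(codes.reshape(-1, 1), ax=axes[1, 0], **strip_kws)  # row classes

# The colorbar is switched off so the main axes keeps its full size and
# stays aligned with the two strips.
sns.heatmap(corr, ax=axes[1, 1], cmap="vlag", center=0, cbar=False,
            xticklabels=False, yticklabels=False)
```

The strips are ordinary heatmaps of integer class codes drawn with a discrete colormap, so their size and placement are controlled entirely by the grid ratios.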
    
