Annotate heatmap with value from Pandas dataframe

野的像风 2021-02-04 20:59

I would like to annotate a heatmap with the values that I pass from a DataFrame into the function below. I have looked at matplotlib.text but have not been able to get the values to show up in the heatmap cells.

4 Answers
  • 2021-02-04 21:31

    The coordinates you were passing to text in your for loop were wrong. Also, you were using plt.colorbar instead of something cleaner like fig.colorbar. Try this (it gets the job done, with no effort made to otherwise clean up the code):

    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    import numpy as np
    import pandas as pd
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    def heatmap_binary(df,
                edgecolors='w',
                #cmap=mpl.cm.RdYlGn,
                log=False):
        width = len(df.columns)/7*10
        height = len(df.index)/7*10

        fig, ax = plt.subplots(figsize=(20,10))#(figsize=(width,height))

        # two colour bands split at 0.05: small values Teal, larger values MidnightBlue
        cmap, norm = mcolors.from_levels_and_colors([0, 0.05, 1],['Teal', 'MidnightBlue'] ) # ['MidnightBlue', Teal]['Darkgreen', 'Darkred']

        data = df.values
        heatmap = ax.pcolor(data,
                            edgecolors=edgecolors,  # put white lines between squares in heatmap
                            cmap=cmap,
                            norm=norm)
        # pcolor cell (row y, column x) spans [x, x+1] x [y, y+1], so its centre is (x+0.5, y+0.5)
        for y in range(data.shape[0]):
            for x in range(data.shape[1]):
                # ax.text (not plt.text) so the label always lands on the heatmap axes
                ax.text(x + 0.5 , y + 0.5, '%.4f' % data[y, x],
                     horizontalalignment='center',
                     verticalalignment='center',
                     color='w')

        ax.autoscale(tight=True)  # get rid of whitespace in margins of heatmap
        ax.set_aspect('equal')  # ensure heatmap cells are square
        ax.xaxis.set_ticks_position('top')  # put column labels at the top
        ax.tick_params(bottom=False, top=False, left=False, right=False)  # turn off ticks

        ax.set_yticks(np.arange(len(df.index)) + 0.5)
        ax.set_yticklabels(df.index, size=20)
        ax.set_xticks(np.arange(len(df.columns)) + 0.5)
        ax.set_xticklabels(df.columns, rotation=90, size= 15)

        # ugliness from http://matplotlib.org/users/tight_layout_guide.html
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", "3%", pad="1%")
        fig.colorbar(heatmap, cax=cax)
        return fig, ax  # hand back the heatmap axes for further OO calls
    

    Then

    df1 = pd.DataFrame(np.random.choice([0, 0.75], size=(4,5)), columns=list('ABCDE'), index=list('WXYZ'))
    heatmap_binary(df1)
    

    gives:

    [Image: resulting heatmap]
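    The coordinate rule behind the fix can be pulled out into a small helper. This is a minimal sketch, assuming the heatmap is drawn with ax.pcolor on the DataFrame's values as in the function above; annotate_pcolor_cells is a hypothetical name, and fixed values replace the random df1 so the result is predictable:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd


def annotate_pcolor_cells(ax, df, fmt="{:.4f}", color="w"):
    """Write each DataFrame value at the centre of its pcolor cell.

    Hypothetical helper: assumes the heatmap was drawn with ax.pcolor(df.to_numpy()),
    so cell (row y, column x) covers [x, x + 1] x [y, y + 1].
    """
    data = df.to_numpy()
    texts = []
    for y in range(data.shape[0]):        # DataFrame rows map to the y axis
        for x in range(data.shape[1]):    # DataFrame columns map to the x axis
            texts.append(
                ax.text(x + 0.5, y + 0.5, fmt.format(data[y, x]),
                        ha="center", va="center", color=color)
            )
    return texts


df1 = pd.DataFrame([[0.0, 0.75], [0.75, 0.0]], columns=list("AB"), index=list("XY"))
fig, ax = plt.subplots()
ax.pcolor(df1.to_numpy())
labels = annotate_pcolor_cells(ax, df1)
```

    If you draw the heatmap with ax.imshow instead, its cell centres fall on the integer coordinates (x, y), so the 0.5 offset should be dropped.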

  • 2021-02-04 21:36

    This is because you're using plt.text after you've added another axes.

    The state machine will plot on the current axes, and after you've added a new one with divider.append_axes, the colorbar's axes is the current one. (Just calling plt.colorbar will not cause this, as it sets the current axes back to the original one afterwards if it creates the axes itself. If a specific axes object is passed in using the cax kwarg, it doesn't reset the "current" axes, as that's not what you'd normally want.)

    Things like this are the main reason that you'll see so many people advising that you use the OO interface to matplotlib instead of the state machine interface. That way you know which axes object you're plotting on.

    For example, in your case, you could have heatmap_binary return the ax object that it creates, and then plot using ax.text instead of plt.text (and similarly for the other plotting methods).
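    A minimal sketch that reproduces the problem with a 2 x 2 array and then applies the OO fix. It assumes a recent matplotlib, in which Figure.colorbar only restores the current axes when it creates the colorbar axes itself:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

data = np.array([[0.0, 0.75], [0.75, 0.0]])

fig, ax = plt.subplots()
mesh = ax.pcolor(data)

# append_axes adds a new Axes to the figure, and that new Axes becomes "current".
cax = make_axes_locatable(ax).append_axes("right", "3%", pad="1%")
fig.colorbar(mesh, cax=cax)  # an explicit cax does not restore the previous current axes

# State-machine call: lands on whatever plt.gca() is now, i.e. the colorbar axes.
plt.text(0.5, 0.5, "lost")

# OO call: always lands on the heatmap axes, whichever axes is current.
ax.text(0.5, 0.5, "{:.4f}".format(data[0, 0]), ha="center", va="center")
```

    The "lost" label ends up on the thin colorbar axes, while the ax.text label sits in the first heatmap cell.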

  • 2021-02-04 21:37

    This functionality is provided by the seaborn package, whose heatmap function can write the value of each cell into the map (annot=True).

    An example usage of seaborn:

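    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set()

    # Load the example flights dataset and convert it from long form to wide form (month x year)
    flights_long = sns.load_dataset("flights")
    flights = flights_long.pivot(index="month", columns="year", values="passengers")

    # Draw a heatmap with the numeric values in each cell
    # (fmt="d" works here because the passenger counts are integers)
    sns.heatmap(flights, annot=True, fmt="d", linewidths=.5)
    plt.show()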
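    With your own float DataFrame the same call needs a different fmt: "d" is an integer format and raises a ValueError on float values. A minimal sketch, assuming seaborn is installed and using fixed float data in the spirit of df1 from the first answer:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Float data shaped like df1 from the first answer (fixed values instead of random ones).
df1 = pd.DataFrame(
    [[0.0, 0.75, 0.0], [0.75, 0.75, 0.0]],
    columns=list("ABC"),
    index=list("XY"),
)

fig, ax = plt.subplots()
# fmt="d" only suits integer data; floats need a float format such as ".4f".
sns.heatmap(df1, annot=True, fmt=".4f", linewidths=0.5, ax=ax)
```

    annot also accepts an array-like with the same shape as the data, in which case those values are written into the cells instead of the plotted ones.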
    
  • 2021-02-04 21:55

    You can also use plotly.figure_factory to create an annotated heatmap from a DataFrame, but you have to convert the DataFrame into lists first.

    import pandas as pd
    import plotly.figure_factory as ff

    # df stands for your DataFrame; a small example is used here
    df = pd.DataFrame([[0.0, 0.75], [0.75, 0.0]], columns=list('AB'), index=list('XY'))

    # the figure factory expects plain nested lists rather than a DataFrame
    z = df.values.tolist()
    x = df.columns.tolist()
    y = df.index.tolist()

    fig = ff.create_annotated_heatmap(z, x=x, y=y, annotation_text=z, colorscale='viridis')

    # set the font size of the cell annotations
    for annotation in fig.layout.annotations:
        annotation.font.size = 12

    # show your heatmap
    fig.show()
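    Passing z itself as annotation_text prints the raw floats, which can get long. A minimal sketch of the DataFrame-to-lists step that also builds rounded label strings; dataframe_to_heatmap_lists is a hypothetical helper and only pandas is needed to run it:

```python
import pandas as pd


def dataframe_to_heatmap_lists(df, fmt="{:.2f}"):
    """Split a numeric DataFrame into plain lists for an annotated heatmap.

    Hypothetical helper: returns z (values), x (column labels), y (index labels)
    and text (the same values formatted as strings, for annotation_text).
    """
    z = df.to_numpy().tolist()
    x = [str(c) for c in df.columns]  # plain strings for the category axis labels
    y = [str(i) for i in df.index]
    text = [[fmt.format(v) for v in row] for row in z]
    return z, x, y, text


df1 = pd.DataFrame([[0.0, 0.75], [0.123456, 1.0]], columns=list("AB"), index=list("XY"))
z, x, y, text = dataframe_to_heatmap_lists(df1)
```

    The results can then be passed on as ff.create_annotated_heatmap(z, x=x, y=y, annotation_text=text, colorscale='viridis'), so the colours still follow the full-precision values while the labels stay short.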
    