Heatmap with text in each cell with matplotlib's pyplot

匿名 (未验证) 提交于 2019-12-03 01:54:01

问题:

I use matplotlib.pyplot.pcolor() to plot a heatmap with matplotlib:

import numpy as np
import matplotlib.pyplot as plt


def heatmap(data, title, xlabel, ylabel):
    """Draw ``data`` as a pcolor heatmap on a new figure, with a colorbar.

    The figure gets the given title and axis labels. Cells are drawn with
    thick black edges, and the 'RdBu' colormap is pinned to the range
    [0.0, 1.0].
    """
    plt.figure()
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    mesh = plt.pcolor(
        data,
        edgecolors='k',
        linewidths=4,
        cmap='RdBu',
        vmin=0.0,
        vmax=1.0,
    )
    plt.colorbar(mesh)


def main():
    """Plot an 8x12 matrix of random values as a heatmap and show it."""
    values = np.random.rand(8, 12)
    heatmap(values, "ROC's AUC", "Timeshift", "Scales")
    plt.show()


if __name__ == "__main__":
    main()

Is there any way to add the corresponding value in each cell, e.g.:

(from Matlab's Customizable Heat Maps)

(I don't need the additional % for my current application, though I'd be curious to know for the future)

回答1:

You need to add all the text yourself by calling axes.text(). Here is an example:

import numpy as np
import matplotlib.pyplot as plt


def show_values(pc, fmt="%.2f", **kw):
    """Write the value of every heatmap cell at the centre of that cell.

    Parameters
    ----------
    pc
        The collection returned by ``pcolor``.
    fmt
        %-style format string applied to each cell value.
    **kw
        Extra keyword arguments forwarded to ``Axes.text``.
    """
    # Make sure the face colors have been computed from the data before
    # they are read back below.
    pc.update_scalarmappable()
    # ``Artist.get_axes()`` no longer exists; the ``axes`` attribute does.
    ax = pc.axes
    # Flatten: recent matplotlib versions may return the array in 2-D form,
    # whereas paths and face colors are one entry per cell.
    values = pc.get_array().ravel()
    # Built-in ``zip`` replaces the Python 2 only ``itertools.izip``.
    for path, facecolor, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Average the vertices without the last two (presumably the
        # duplicated closing points of the cell polygon) to get the centre.
        x, y = path.vertices[:-2, :].mean(0)
        # Black text on light cells, white text on dark cells.
        if np.all(facecolor[:3] > 0.5):
            text_color = (0.0, 0.0, 0.0)
        else:
            text_color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)


title = "ROC's AUC"
xlabel = "Timeshift"
ylabel = "Scales"
data = np.random.rand(8, 12)

plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)

show_values(c)

plt.colorbar(c)

the output:



回答2:

You could use Seaborn, which is a Python visualization library based on matplotlib that provides a high-level interface for drawing attractive statistical graphics.

Heatmap example:

import matplotlib.pyplot as plt
import seaborn as sns

sns.set()

# Long-form example data: one row per (year, month) with a passenger count.
flights_long = sns.load_dataset("flights")
# Keyword arguments are required here: recent pandas versions reject
# positional arguments to DataFrame.pivot.
flights = flights_long.pivot(index="month", columns="year", values="passengers")

# Keep the returned Axes so the figure can be retrieved for saving.
ax = sns.heatmap(flights, annot=True, fmt="d")

# To save the heatmap as a file (done before plt.show(), since the figure
# may no longer be usable once the window has been closed):
fig = ax.get_figure()
fig.savefig('heatmap.pdf')

# To display the heatmap
plt.show()

Documentation: https://seaborn.pydata.org/generated/seaborn.heatmap.html



回答3:

In case it is of interest to anyone, below is the code I use to imitate the picture from Matlab's Customizable Heat Maps (which I had included in the question).

import numpy as np
import matplotlib.pyplot as plt


def show_values(pc, fmt="%.2f", **kw):
    '''
    Heatmap with text in each cell with matplotlib's pyplot
    Source: http://stackoverflow.com/a/25074150/395857
    By HYRY

    Writes the value of every cell of the pcolor collection ``pc`` at the
    centre of that cell, formatted with the %-style string ``fmt``. Extra
    keyword arguments are forwarded to ``Axes.text``.
    '''
    # Make sure the face colors have been computed from the data before
    # they are read back below.
    pc.update_scalarmappable()
    # ``Artist.get_axes()`` no longer exists; the ``axes`` attribute does.
    ax = pc.axes
    # Flatten: recent matplotlib versions may return the array in 2-D form,
    # whereas paths and face colors are one entry per cell.
    values = pc.get_array().ravel()
    # Built-in ``zip`` replaces the Python 2 only ``itertools.izip``.
    for path, facecolor, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Average the vertices without the last two (presumably the
        # duplicated closing points of the cell polygon) to get the centre.
        x, y = path.vertices[:-2, :].mean(0)
        # Black text on light cells, white text on dark cells.
        if np.all(facecolor[:3] > 0.5):
            text_color = (0.0, 0.0, 0.0)
        else:
            text_color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)


def cm2inch(*tupl):
    '''
    Specify figure size in centimeter in matplotlib
    Source: http://stackoverflow.com/a/22787457/395857
    By gns-ank

    Accepts either separate numbers, ``cm2inch(40, 20)``, or one tuple,
    ``cm2inch((40, 20))``, and returns a tuple of the values in inches.
    '''
    inch = 2.54
    # isinstance also accepts tuple subclasses; the emptiness check avoids an
    # IndexError when called without arguments.
    if tupl and isinstance(tupl[0], tuple):
        tupl = tupl[0]
    return tuple(i / inch for i in tupl)


def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels):
    '''
    Inspired by:
    - http://stackoverflow.com/a/16124677/395857
    - http://stackoverflow.com/a/25074150/395857

    Draws the 2-D array ``AUC`` as an annotated heatmap on a new figure,
    with a colorbar, the given title and axis labels, and one tick label per
    column (``xticklabels``) and per row (``yticklabels``).
    '''
    # Plot it out
    fig, ax = plt.subplots()
    c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2,
                  cmap='RdBu', vmin=0.0, vmax=1.0)

    # Put the major ticks at the middle of each cell
    ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)

    # Set tick labels
    ax.set_xticklabels(list(xticklabels), minor=False)
    ax.set_yticklabels(list(yticklabels), minor=False)

    # Set title and x/y labels
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Remove last blank column
    ax.set_xlim(0, AUC.shape[1])

    # Turn off all the tick marks while keeping the labels. The per-tick
    # ``tick1On``/``tick2On`` attributes are gone from current matplotlib,
    # so tick_params is used instead.
    ax.tick_params(axis='both', which='major',
                   bottom=False, top=False, left=False, right=False)

    # Add color bar
    fig.colorbar(c, ax=ax)

    # Add text in each cell
    show_values(c)

    # Resize
    fig.set_size_inches(cm2inch(40, 20))


def main():
    '''Build a random 10x19 matrix, plot it, save it as PNG and show it.'''
    x_axis_size = 19
    y_axis_size = 10
    title = "ROC's AUC"
    xlabel = "Timeshift"
    ylabel = "Scales"
    data = np.random.rand(y_axis_size, x_axis_size)
    xticklabels = list(range(1, x_axis_size + 1))  # could be text
    yticklabels = list(range(1, y_axis_size + 1))  # could be text
    heatmap(data, title, xlabel, ylabel, xticklabels, yticklabels)
    # use format='svg' or 'pdf' for vectorial pictures
    plt.savefig('image_output.png', dpi=300, format='png', bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    main()
    # For profiling, replace the call above with:
    #   import cProfile; cProfile.run('main()')

Output:

It looks nicer when there are some patterns:



标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!