可以将文章内容翻译成中文,广告屏蔽插件可能会导致该功能失效(如失效,请关闭广告屏蔽插件后再试):
问题:
I use matplotlib.pyplot.pcolor() to plot a heatmap with matplotlib:
import numpy as np
import matplotlib.pyplot as plt


def heatmap(data, title, xlabel, ylabel):
    """Open a new figure showing *data* as a pcolor heatmap plus a colorbar.

    Every cell is outlined in black (edge width 4) and colored with the
    'RdBu' colormap, whose range is pinned to [0.0, 1.0].
    """
    mesh_style = {
        "edgecolors": 'k',
        "linewidths": 4,
        "cmap": 'RdBu',
        "vmin": 0.0,
        "vmax": 1.0,
    }
    plt.figure()
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    mesh = plt.pcolor(data, **mesh_style)
    plt.colorbar(mesh)


def main():
    """Render an 8x12 grid of uniform random values and show the window."""
    labels = {"title": "ROC's AUC", "xlabel": "Timeshift", "ylabel": "Scales"}
    grid = np.random.rand(8, 12)
    heatmap(grid, **labels)
    plt.show()


if __name__ == "__main__":
    main()
Is there any way to add the corresponding value in each cell, e.g.:
(from Matlab's Customizable Heat Maps)
(I don't need the additional % for my current application, though I'd be curious to know how to do it in the future.)
回答1:
You need to add all the text by calling `axes.text()`; here is an example:
import numpy as np
import matplotlib.pyplot as plt


def show_values(pc, fmt="%.2f", **kw):
    """Write each cell's value at the center of that cell in a pcolor plot.

    Parameters
    ----------
    pc : the collection returned by ``pcolor``.
    fmt : %-style format string applied to each cell value.
    **kw : extra keyword arguments forwarded to ``Axes.text``.

    The text is black on light cells (all RGB channels > 0.5) and white
    otherwise, so it stays readable against the cell color.
    """
    # Refresh face colors from the data/colormap before reading them.
    pc.update_scalarmappable()
    # Artist.get_axes() is gone from current matplotlib; .axes replaces it.
    ax = pc.axes
    # Some matplotlib versions store the pcolor array as 2-D; flatten it so it
    # lines up with the one-path-per-cell sequence.
    # NOTE(review): assumes no masked cells; with masked data the paths and
    # values may not line up.
    values = np.ravel(pc.get_array())
    # Built-in zip replaces itertools.izip, which does not exist on Python 3.
    for path, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Midpoint of the cell's bounding box: unaffected by a repeated closing
        # vertex, unlike averaging a fixed slice of the vertex list.
        verts = path.vertices
        x, y = (verts.min(axis=0) + verts.max(axis=0)) / 2
        text_color = (0.0, 0.0, 0.0) if np.all(color[:3] > 0.5) else (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)


title = "ROC's AUC"
xlabel = "Timeshift"
ylabel = "Scales"
data = np.random.rand(8, 12)

plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)
show_values(c)
plt.colorbar(c)
The output:
回答2:
You could use Seaborn, which is a Python visualization library based on matplotlib that provides a high-level interface for drawing attractive statistical graphics.
Heatmap example:
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default plotting theme.
sns.set()

flights_long = sns.load_dataset("flights")
# Reshape to a month x year table of passenger counts. pandas 2.0 no longer
# accepts positional arguments to DataFrame.pivot, so pass them by keyword.
flights = flights_long.pivot(index="month", columns="year", values="passengers")

# annot=True writes each cell's value; fmt="d" formats it as an integer.
# sns.heatmap returns the Axes it drew on, which is needed to save the figure.
ax = sns.heatmap(flights, annot=True, fmt="d")

# To display the heatmap
plt.show()

# To save the heatmap as a file:
fig = ax.get_figure()
fig.savefig('heatmap.pdf')
Documentation: https://seaborn.pydata.org/generated/seaborn.heatmap.html
回答3:
In case it's of interest to anyone, below is the code I use to imitate the picture from MATLAB's Customizable Heat Maps that I included in the question.
import numpy as np
import matplotlib.pyplot as plt


def show_values(pc, fmt="%.2f", **kw):
    '''
    Heatmap with text in each cell with matplotlib's pyplot
    Source: http://stackoverflow.com/a/25074150/395857
    By HYRY

    Writes each cell's value (formatted with the %-style ``fmt``) at the
    center of that cell of the pcolor collection ``pc``. The text is black on
    light cells (all RGB channels > 0.5) and white otherwise. Extra keyword
    arguments are forwarded to ``Axes.text``.
    '''
    # Refresh face colors from the data/colormap before reading them.
    pc.update_scalarmappable()
    # Artist.get_axes() is gone from current matplotlib; .axes replaces it.
    ax = pc.axes
    # Some matplotlib versions store the pcolor array as 2-D; flatten it so it
    # lines up with the one-path-per-cell sequence.
    # NOTE(review): assumes no masked cells; with masked data the paths and
    # values may not line up.
    values = np.ravel(pc.get_array())
    # Built-in zip replaces itertools.izip, which does not exist on Python 3.
    for path, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Midpoint of the cell's bounding box: unaffected by a repeated closing
        # vertex, unlike averaging a fixed slice of the vertex list.
        verts = path.vertices
        x, y = (verts.min(axis=0) + verts.max(axis=0)) / 2
        text_color = (0.0, 0.0, 0.0) if np.all(color[:3] > 0.5) else (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)


def cm2inch(*tupl):
    '''
    Specify figure size in centimeter in matplotlib
    Source: http://stackoverflow.com/a/22787457/395857
    By gns-ank

    Accepts either separate numbers, ``cm2inch(40, 20)``, or a single tuple,
    ``cm2inch((40, 20))``, and returns a tuple of the values in inches.
    '''
    inch = 2.54
    # A single tuple argument is unpacked; otherwise the arguments themselves
    # are the values. The length check keeps cm2inch() from raising IndexError.
    values = tupl[0] if tupl and isinstance(tupl[0], tuple) else tupl
    return tuple(v / inch for v in values)


def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels):
    '''
    Draw AUC as a heatmap with dashed cell borders, tick labels centered on
    the cells, no tick marks, a colorbar, and each cell's value written in it.

    Inspired by:
    - http://stackoverflow.com/a/16124677/395857
    - http://stackoverflow.com/a/25074150/395857
    '''
    fig, ax = plt.subplots()
    c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2,
                  cmap='RdBu', vmin=0.0, vmax=1.0)

    # Put the major ticks at the middle of each cell.
    ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
    ax.set_xticklabels(xticklabels, minor=False)
    ax.set_yticklabels(yticklabels, minor=False)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Limit the x-range to exactly the data columns (removes a blank column).
    ax.set_xlim(0, AUC.shape[1])

    # Hide the major tick marks on all four sides but keep the labels.
    # Assigning Tick.tick1On / tick2On no longer has any effect in current
    # matplotlib, so use tick_params instead.
    ax.tick_params(axis='both', which='major',
                   bottom=False, top=False, left=False, right=False)

    fig.colorbar(c, ax=ax)

    # Add text in each cell.
    show_values(c)

    fig.set_size_inches(cm2inch(40, 20))


def main():
    x_axis_size = 19
    y_axis_size = 10
    title = "ROC's AUC"
    xlabel = "Timeshift"
    ylabel = "Scales"
    data = np.random.rand(y_axis_size, x_axis_size)
    xticklabels = range(1, x_axis_size + 1)  # could be text
    yticklabels = range(1, y_axis_size + 1)  # could be text
    heatmap(data, title, xlabel, ylabel, xticklabels, yticklabels)
    # Use format='svg' or 'pdf' for vector output.
    plt.savefig('image_output.png', dpi=300, format='png', bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    main()
    # To profile instead: import cProfile; cProfile.run('main()')
Output:
It looks nicer when there are some patterns: