Question
I have a function that takes a 3D dataset of arbitrary length (dates, prices as floats, and a resulting value as a float) and makes a set of seaborn heatmaps split by year. Simplified, it looks like this (the number of years varies by dataset, so it needs to scale to any number of years):
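```python
import matplotlib.pyplot as plt
import seaborn as sns

def makePlots(data):
    # split data by year (details omitted); assume `data` ends up as a
    # dict mapping each year to a 2D table of values
    years = sorted(data)
    # squeeze=False keeps axs 2D (1 x numYears) even for a single year
    fig, axs = plt.subplots(1, len(years), squeeze=False)
    for x, year in enumerate(years):
        sns.heatmap(data[year], ax=axs[0, x])
    return axs
```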
This outputs a single matplotlib figure with one heatmap per year, side by side in a single row, as in this example:
[Image: single plotted dataset]
Now I have a higher-level function that takes two datasets (each with an arbitrary number of years) and should display their heatmap rows one above the other for comparison. I would like it to take the figures created by makePlots and simply stack them vertically, as in this example:
[Image: two plotted datasets]
```python
def compareData(data1, data2):
    fig1 = makePlots(data1)
    fig2 = makePlots(data2)
    fig, (ax1, ax2) = plt.subplots(2, 1)
    ax1 = fig1
    ax2 = fig2
    plt.show()
```
This code runs, but not as intended: it opens three plot windows, one with data1 plotted correctly, one with data2 plotted correctly, and one with an empty two-row subplot grid. Is there any way to nest the makePlots plots in a new figure, one above the other? I have also tried returning plt.gcf().
All the other answers on Stack Overflow rely on passing the axes to the plotting function, but since each dataset has an arbitrary number of axes (years) and I would eventually like to compare an arbitrary number of datasets, this does not seem ideal (nor can I work out how to implement it, since each row can have a different number of years).
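A minimal sketch of why three windows appear, assuming only standard matplotlib behavior: every call to plt.subplots creates a brand-new Figure, and assigning `ax1 = fig1` merely rebinds a local variable name; it does not move any axes into the 2x1 figure. The helper `makeAxes` is hypothetical and stands in for makePlots with the heatmaps omitted.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no windows actually open
import matplotlib.pyplot as plt

def makeAxes():
    # hypothetical stand-in for makePlots: plt.subplots always makes a new Figure
    fig, axs = plt.subplots(1, 3)
    return axs

axs1 = makeAxes()
axs2 = makeAxes()
fig, (ax1, ax2) = plt.subplots(2, 1)
ax1 = axs1  # only rebinds the name; nothing is moved into `fig`
ax2 = axs2

print(len(plt.get_fignums()))  # 3: three separate figures exist
print(len(fig.axes))           # 2: the 2x1 figure keeps its own empty axes
print(axs1[0].figure is fig)   # False: the first row's axes live elsewhere
```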
Answer 1:
I wouldn't recommend it, but you can add subplots incrementally with fig.add_subplot(nrows, ncols, index).
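The index argument of add_subplot is 1-based and counts across each row before moving down (row-major), which is why the code below computes `row * len(years) + ii + 1`. A small sketch checking that arithmetic on a 2x3 grid, using a plain figure with no data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for a headless check
import matplotlib.pyplot as plt

fig = plt.figure()
nrows, ncols = 2, 3
positions = {}
for row in range(nrows):
    for col in range(ncols):
        # 1-based, row-major index, the same formula used in the answer
        index = row * ncols + col + 1
        ax = fig.add_subplot(nrows, ncols, index)
        spec = ax.get_subplotspec()
        # record the (row, column) cell the new axes actually occupies
        positions[index] = (spec.rowspan.start, spec.colspan.start)

print(positions[1], positions[3], positions[4], positions[6])
# (0, 0) (0, 2) (1, 0) (1, 2)
```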
So your two functions would look something like this:
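```python
import matplotlib.pyplot as plt
import seaborn as sns

def compareData(data1, data2):
    fig = plt.figure()
    makePlots(data1, row=0, fig=fig)  # top row
    makePlots(data2, row=1, fig=fig)  # bottom row
    plt.show()

def makePlots(data, row, fig):
    years = ... # parse data here
    for ii, year in enumerate(years):
        # 2 rows in total; each row uses its own column count (len(years)),
        # so rows may hold different numbers of years. Index is 1-based.
        ax = fig.add_subplot(2, len(years), row * len(years) + ii + 1)
        sns.heatmap(data[year], ax=ax)
```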
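The same idea extends to the question's eventual goal of comparing any number of datasets by passing the total row count instead of hard-coding 2. This is a hypothetical generalization, not part of the answer: the `nrows` parameter, the `*datasets` signature, and the assumption that each dataset is already parsed into a dict `{year: 2D array}` are all illustrative, and the random arrays are synthetic.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line and call plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def makePlots(data, row, nrows, fig):
    # assumption: data is already parsed into a dict {year: 2D array}
    years = sorted(data)
    for ii, year in enumerate(years):
        # each row uses its own column count, so rows may differ in length
        ax = fig.add_subplot(nrows, len(years), row * len(years) + ii + 1)
        sns.heatmap(data[year], ax=ax)
        ax.set_title(str(year))

def compareData(*datasets):
    # hypothetical: one row per dataset, any number of datasets
    fig = plt.figure(figsize=(12, 3 * len(datasets)))
    for row, data in enumerate(datasets):
        makePlots(data, row=row, nrows=len(datasets), fig=fig)
    return fig

rng = np.random.default_rng(0)
data1 = {y: rng.random((4, 5)) for y in (2018, 2019, 2020)}
data2 = {y: rng.random((4, 5)) for y in (2019, 2020)}
data3 = {y: rng.random((4, 5)) for y in (2017, 2018, 2019, 2020)}
fig = compareData(data1, data2, data3)
```

Note that seaborn adds a colorbar axes next to each heatmap, so `fig.axes` holds more axes than there are heatmaps.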
This hopefully addresses your question.
However, you only have this problem because you are mixing data parsing and plotting in the same function. My advice would be to parse the data first, then pass the resulting data structure to separate plotting functions.
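A minimal sketch of that separation, under several stated assumptions: the raw data is a pandas DataFrame with columns `date`, `price`, and `value`; each yearly heatmap is laid out as price rows by month columns (a purely illustrative choice); and matplotlib 3.4 or later is available, because it uses `Figure.subfigures`, which gives each dataset its own nested figure-like container. This is a swapped-in technique, not the answer's add_subplot method. `parseByYear`, `plotRow`, and `fakeData` are hypothetical names.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; call plt.show() instead when interactive
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

def parseByYear(df):
    # hypothetical parser: returns {year: price x month table of mean value}
    out = {}
    for year, group in df.groupby(df["date"].dt.year):
        monthly = group.assign(month=group["date"].dt.month)
        out[year] = monthly.pivot_table(
            index="price", columns="month", values="value", aggfunc="mean")
    return out

def plotRow(byYear, container):
    # container may be a Figure or a SubFigure; both provide .subplots()
    years = sorted(byYear)
    axs = container.subplots(1, len(years), squeeze=False)
    for ax, year in zip(axs[0], years):
        sns.heatmap(byYear[year], ax=ax)
        ax.set_title(str(year))

def compareData(*parsed):
    # one subfigure per dataset, stacked vertically
    fig = plt.figure(figsize=(12, 3 * len(parsed)))
    subfigs = fig.subfigures(len(parsed), 1, squeeze=False)
    for subfig, byYear in zip(subfigs[:, 0], parsed):
        plotRow(byYear, subfig)
    return fig

rng = np.random.default_rng(0)

def fakeData(start, end):
    # synthetic data: three price levels for every month in the range
    dates = pd.date_range(start, end, freq="MS")
    df = pd.DataFrame([(d, p) for d in dates for p in (10.0, 20.0, 30.0)],
                      columns=["date", "price"])
    df["value"] = rng.random(len(df))
    return df

data1 = fakeData("2018-01-01", "2020-12-01")
data2 = fakeData("2019-01-01", "2020-12-01")
fig = compareData(parseByYear(data1), parseByYear(data2))
```

Because plotRow only needs something with a `.subplots()` method, the same function draws into a standalone figure or into one row of a comparison, which sidesteps precomputing a shared grid for rows with different numbers of years.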
Source: https://stackoverflow.com/questions/61839595/nesting-or-combining-matplotlib-figures-and-plots