Question:
I have the following plot:

    fig,ax = plt.subplots(5,2,sharex=True,sharey=True,figsize=fig_size)

and now I would like to give this plot common x-axis labels and y-axis labels. By "common", I mean that there should be one big x-axis label below the whole grid of subplots, and one big y-axis label to the right. I can't find anything about this in the documentation for plt.subplots, and my searches suggest that I need to make a big plt.subplot(111) to start with - but how do I then put my 5*2 subplots into that using plt.subplots?
Answer 1:
This looks like what you actually want. It applies the same approach as this answer to your specific case:

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(6, 6))

    # Place the shared labels in figure coordinates (0 to 1 across the whole figure)
    fig.text(0.5, 0.04, 'common X', ha='center')
    fig.text(0.04, 0.5, 'common Y', va='center', rotation='vertical')
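
If you are on Matplotlib 3.4 or newer, the figure-level methods supxlabel and supylabel should give a similar result without hand-picked coordinates. The following is only a minimal sketch under that version assumption; the names x_label and y_label are introduced here for illustration.

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(6, 6))

# Figure-level labels, centred on the whole grid (assumes Matplotlib >= 3.4)
x_label = fig.supxlabel('common X')
y_label = fig.supylabel('common Y')
```

Both calls return the Text object they create, so font size or position can still be adjusted afterwards. Finish with plt.show() or fig.savefig(...) as usual.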
Answer 2:
Without sharex=True, sharey=True, every subplot draws its own tick labels. With them, the result looks nicer:

    import matplotlib.pyplot as plt
    import numpy as np

    fig, axes2d = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(6,6))

    for i, row in enumerate(axes2d):
        for j, cell in enumerate(row):
            cell.imshow(np.random.rand(32,32))

    plt.tight_layout()
But if you want to add additional labels, you should add them only to the edge plots:

    fig, axes2d = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(6,6))

    for i, row in enumerate(axes2d):
        for j, cell in enumerate(row):
            cell.imshow(np.random.rand(32,32))
            # Label only the bottom row ...
            if i == len(axes2d) - 1:
                cell.set_xlabel("noise column: {0:d}".format(j + 1))
            # ... and only the left column
            if j == 0:
                cell.set_ylabel("noise row: {0:d}".format(i + 1))

    plt.tight_layout()

Adding a label to every plot would spoil it (maybe there is a way to automatically detect repeated labels, but I am not aware of one).
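
One candidate for the automatic approach mentioned above is Axes.label_outer(), which clears axis labels and tick labels on subplots that are not on the outer edge of the grid. The sketch below labels every cell and then lets label_outer() tidy up; it reuses the variable names from the snippets above and has only been reasoned through for a regular grid created by plt.subplots.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axes2d = plt.subplots(nrows=3, ncols=3, sharex=True, sharey=True, figsize=(6, 6))

for i, row in enumerate(axes2d):
    for j, cell in enumerate(row):
        cell.imshow(np.random.rand(32, 32))
        # Label every subplot first ...
        cell.set_xlabel("noise column: {0:d}".format(j + 1))
        cell.set_ylabel("noise row: {0:d}".format(i + 1))
        # ... then clear the labels that are not in the bottom row / left column
        cell.label_outer()

fig.tight_layout()
```

This keeps the labelling logic free of index checks, at the cost of depending on the subplots having been created on a grid.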
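Answer 3:
Since the command

    fig,ax = plt.subplots(5,2,sharex=True,sharey=True,figsize=fig_size)

you used returns a tuple consisting of the figure and an array of the axes instances, it is already sufficient to do something like the following (mind that I've changed fig,ax to fig,axes, and that a 5x2 grid comes back as a 2D array, so the loop goes over axes.flat):

    # fig_size is the figure-size tuple from the question
    fig,axes = plt.subplots(5,2,sharex=True,sharey=True,figsize=fig_size)

    # Iterating over axes directly would yield rows, not individual subplots
    for ax in axes.flat:
        ax.set_xlabel('Common x-label')
        ax.set_ylabel('Common y-label')
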
If you want to change some details on a specific subplot, you can access it via axes[i, j], where i and j are the row and column indices of the subplot.
It might also be very helpful to include a

    fig.tight_layout()

at the end of the file, before the plt.show(), in order to avoid overlapping labels.
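
To make the shape of the returned object concrete, here is a small sketch that inspects it. The figure size is left at its default, and all_axes is a name introduced only for this example.

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(5, 2, sharex=True, sharey=True)

# For a 5x2 grid, axes is a 2D NumPy array with one Axes per cell
print(type(axes).__name__, axes.shape)

# Iterating the array directly yields rows; .flat yields the individual Axes
all_axes = list(axes.flat)

# Index with [row, column] to change a single subplot
axes[4, 0].set_title('bottom-left subplot')
```

For a single row or a single column, plt.subplots returns a 1D array instead, unless squeeze=False is passed.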
Answer 4:
I ran into a similar problem while plotting a grid of graphs. The graphs consisted of two parts (top and bottom). The y-label was supposed to be centered over both parts.
I did not want to use a solution that depends on knowing the position in the outer figure (like fig.text()), so I manipulated the y-position of the set_ylabel() function. It is usually 0.5, the middle of the plot it is added to. As the padding between the parts (hspace) in my code was zero, I could calculate the middle of the two parts relative to the upper part.

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec

    # Create outer and inner grid
    outerGrid = gridspec.GridSpec(2, 3, width_ratios=[1,1,1], height_ratios=[1,1])
    somePlot = gridspec.GridSpecFromSubplotSpec(2, 1,
        subplot_spec=outerGrid[3], height_ratios=[1,3], hspace = 0)

    # Add two partial plots
    partA = plt.subplot(somePlot[0])
    partB = plt.subplot(somePlot[1])

    # No x-ticks for the upper plot
    plt.setp(partA.get_xticklabels(), visible=False)

    # The center is (height(top)-height(bottom))/(2*height(top))
    # Simplified to 0.5 - height(bottom)/(2*height(top))
    mid = 0.5-somePlot.get_height_ratios()[1]/(2.*somePlot.get_height_ratios()[0])

    # Place the y-label
    partA.set_ylabel('shared label', y = mid)

    plt.show()

Downsides:
The horizontal distance to the plot is based on the top part, the bottom ticks might extend into the label.
The formula does not take space between the parts into account.
Throws an exception when the height of the top part is 0.
There is probably a general solution that takes padding between figures into account.
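
One possible way to take the space between the parts into account is to read the actual axes positions instead of the height ratios. The sketch below is an assumption-laden illustration, not the author's method: shared_label_y is a hypothetical helper, and it presumes the axes positions are already final when it is called.

```python
import matplotlib.pyplot as plt


def shared_label_y(top_ax, bottom_ax):
    """Return the y position, in top_ax axes coordinates, of the midpoint
    of the region spanned by top_ax and bottom_ax (hypothetical helper)."""
    top_box = top_ax.get_position()        # Bbox in figure coordinates
    bottom_box = bottom_ax.get_position()
    mid_fig = (top_box.y1 + bottom_box.y0) / 2.0
    # Convert from figure coordinates to the top axes' own 0-1 coordinates
    return (mid_fig - top_box.y0) / top_box.height


fig = plt.figure()
grid = fig.add_gridspec(2, 1, height_ratios=[1, 3], hspace=0.3)
part_a = fig.add_subplot(grid[0])
part_b = fig.add_subplot(grid[1])

# Same idea as above, but the gap created by hspace is included in the midpoint
part_a.set_ylabel('shared label', y=shared_label_y(part_a, part_b))
```

With hspace=0 and height ratios [1, 3] this reduces to the same value as the formula above. If tight_layout() or a constrained layout moves the axes afterwards, the position has to be recomputed, and a top part of zero height still fails.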
Answer 5:
I discovered an alternative method: if you know the bottom and top kwargs that went into a GridSpec initialization, or you otherwise know the edge positions of your axes in Figure coordinates, you can also specify the ylabel position in Figure coordinates with some fancy "transform" magic. For example:

    import matplotlib.pyplot as plt
    import matplotlib.transforms as mtransforms

    bottom, top = .1, .9
    # bottom/top are GridSpec parameters, so they are passed through gridspec_kw
    f, a = plt.subplots(nrows=2, ncols=1, gridspec_kw=dict(bottom=bottom, top=top))
    avepos = (bottom+top)/2

    # Set the text first: set_ylabel() may reset the label's y position to 0.5
    a[0].set_ylabel('Hello, world!')
    a[0].yaxis.label.set_transform(mtransforms.blended_transform_factory(
        mtransforms.IdentityTransform(), f.transFigure  # specify x, y transform
    ))  # changed from default blend (IdentityTransform(), a[0].transAxes)
    a[0].yaxis.label.set_position((0, avepos))

...and you should see that the label still appropriately adjusts left-right to keep from overlapping with ticklabels, just like normal -- but now it will adjust to be always exactly between the desired subplots.
Furthermore, if you don't even use set_position, the ylabel will show up by default exactly halfway up the figure. I'm guessing this is because when the label is finally drawn, matplotlib uses 0.5 for the y-coordinate without checking whether the underlying coordinate transform has changed.