I want to plot some images side by side in my Jupyter notebook so that they take up less display space. For example:
This is done through
# NOTE(review): this snippet appears truncated in the original question; as
# written it binds whatever `plt` refers to (presumably the pyplot module,
# not a Figure) to `fig`.
fig = plt
Use the following `align_figures()` function:
def align_figures():
    """Display every open pyplot figure side by side in a Jupyter notebook.

    Each open figure is rendered to PNG through IPython's display
    formatter, embedded as a base64 ``<img>`` tag with ``align="left"``,
    and all tags are emitted together inside a single HTML ``<div>``.
    Afterwards all pyplot figures are closed.

    Requires a running IPython kernel whose inline figure format includes
    PNG; otherwise the ``'image/png'`` lookup raises ``KeyError``.
    """
    import base64

    import matplotlib.pyplot as plt
    from IPython import get_ipython
    from IPython.display import display_html
    from ipykernel.pylab.backend_inline import show
    from matplotlib._pylab_helpers import Gcf

    formatter = get_ipython().display_formatter
    images = []
    for figure_manager in Gcf.get_all_fig_managers():
        fig = figure_manager.canvas.figure
        # format() returns (data, metadata); data maps MIME type -> payload.
        png = formatter.format(fig)[0]['image/png']
        # b64encode (unlike encodebytes) inserts no newlines into the URI.
        src = base64.b64encode(png).decode('ascii')
        images.append(
            '<img style="margin:0" align="left" '
            'src="data:image/png;base64,{}"/>'.format(src)
        )
    html = "<div>{}</div>".format("".join(images))
    # NOTE(review): resetting this ipykernel flag presumably stops the inline
    # backend from auto-displaying the same figures again at cell end.
    show._draw_called = False
    plt.close('all')
    display_html(html, raw=True)
Here is a test:
import matplotlib.pyplot as plt

# Three empty 4x3-inch figures; align_figures() then shows them in one row.
fig1, ax1 = plt.subplots(figsize=(4, 3))
fig2, ax2 = plt.subplots(figsize=(4, 3))
fig3, ax3 = plt.subplots(figsize=(4, 3))
align_figures()
The code assumes that the inline figure output format is PNG.
First, let me recommend that you use a colormap other than jet, for the reasons detailed in "A better colormap for matplotlib".
As for what you want to do, you can achieve it with a modified version of the code from: https://stackoverflow.com/a/26432947/835607
I've extended that function to handle the z-axis of 3D plots as well as the colorbars you are using.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.transforms import Bbox
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
def full_extent(ax, xpad=0.0, ypad=0.0, cbar=None):
    """Return the full extent of an Axes, including its labels and titles.

    Modified from https://stackoverflow.com/a/26432947/835607

    The returned box covers the Axes itself, its title, its axis labels
    and tick labels (including the z axis of 3D Axes) and, when *cbar* is
    given, the colorbar Axes together with its tick labels and axis labels.

    Parameters
    ----------
    ax : matplotlib Axes
        The Axes to measure.
    xpad, ypad : float, optional
        Fractional padding. The box is expanded about its centre by a
        factor of ``1 + xpad`` horizontally and ``1 + ypad`` vertically;
        increase these if labels come out slightly chopped off.
    cbar : matplotlib Colorbar, optional
        The object returned by ``fig.colorbar``; its ``.ax`` is included.

    Returns
    -------
    matplotlib.transforms.Bbox
        The extent in display (pixel) units. Transform it with
        ``fig.dpi_scale_trans.inverted()`` to get inches for
        ``savefig(bbox_inches=...)``.

    Notes
    -----
    Text extents are only defined once the figure has been drawn, so this
    function always calls ``ax.figure.canvas.draw()`` before measuring.
    """
    ax.figure.canvas.draw()

    def _has_extent(artist):
        # Hidden artists report a placeholder Bbox.unit() near the display
        # origin, and empty Text reports a zero-size box at its anchor point.
        # Neither is visible ink, and either could stretch the union.
        if not artist.get_visible():
            return False
        get_text = getattr(artist, 'get_text', None)
        return get_text is None or get_text() != ''

    items = ax.get_xticklabels() + ax.get_yticklabels()
    items += [ax.title, ax.xaxis.label, ax.yaxis.label]
    if hasattr(ax, 'zaxis'):  # only 3D Axes carry a z axis
        items += ax.get_zticklabels() + [ax.zaxis.label]

    # The Axes (and colorbar Axes) are always included, so the union is
    # never empty.
    boxes = [ax.get_window_extent()]
    if cbar is not None:
        cbar_ax = cbar.ax
        boxes.append(cbar_ax.get_window_extent())
        # Collect both axes so the tick labels and axis label of vertical
        # and horizontal colorbars are covered alike.
        items += cbar_ax.get_xticklabels() + cbar_ax.get_yticklabels()
        items += [cbar_ax.xaxis.label, cbar_ax.yaxis.label]

    boxes += [item.get_window_extent() for item in items if _has_extent(item)]
    return Bbox.union(boxes).expanded(1.0 + xpad, 1.0 + ypad)
Now, as an example, I plot three subplots and save each of them to a separate file. Note that the `full_extent` function takes `cbar`, `xpad`, and `ypad` as arguments. For plots that have colorbars, make sure to pass the colorbar (the object returned by `fig.colorbar`) to the function. You may also need to play around with the padding to get the best results.
# Make an example figure with 3 subplots, then save each subplot (with its
# labels and, for ax2, its colorbar) to its own cropped file.
fig = plt.figure(figsize=(9, 4))

# 3D plot
ax1 = fig.add_subplot(1, 3, 1, projection='3d')
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)
surf = ax1.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis',
                        linewidth=0, antialiased=False)
ax1.set_zlim(-1.01, 1.01)
ax1.zaxis.set_major_locator(LinearLocator(10))
ax1.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

# This plot has a colorbar that we'll need to pass to full_extent
ax2 = fig.add_subplot(1, 3, 2)
data = np.clip(np.random.randn(250, 250), -1, 1)
cax = ax2.imshow(data, interpolation='nearest', cmap='viridis')
ax2.set_title('Gaussian noise')
cbar = fig.colorbar(cax, ax=ax2)  # explicit parent Axes for the colorbar
ax2.set_xlabel('asdf')
ax2.set_ylabel('Some Cool Data')

# 3rd plot for fun
ax3 = fig.add_subplot(1, 3, 3)
ax3.plot([1, 4, 5, 7, 7], [3, 5, 7, 8, 3], 'ko--')
ax3.set_ylabel('adsf')
ax3.set_title('a title')

plt.tight_layout()  # no overlapping labels

# Save everything before plt.show(): with a blocking GUI backend show()
# does not return until the window is closed. full_extent() draws the
# figure itself, so no prior show() is needed to give text an extent.
fig.savefig('full_figure.png')  # just in case

# full_extent() returns pixels; savefig(bbox_inches=...) expects inches.
to_inches = fig.dpi_scale_trans.inverted()
extent1 = full_extent(ax1).transformed(to_inches)
fig.savefig('ax1_figure.png', bbox_inches=extent1)
extent2 = full_extent(ax2, .05, .1, cbar).transformed(to_inches)
fig.savefig('ax2_figure.png', bbox_inches=extent2)
extent3 = full_extent(ax3).transformed(to_inches)
fig.savefig('ax3_figure.png', bbox_inches=extent3)

plt.show()  # display in the notebook
This draws the three plots side by side in one row, as you wanted, and creates cropped output images such as this one: