Jupyter Notebook: Output image in previous line

灰色年华 2021-01-26 19:11

I want to plot several images side by side in my Jupyter notebook, so that the output takes up less display space. For example:

This is done through

fig = plt         


        
2 Answers
  • 2021-01-26 20:12

    Use the following align_figures() function:

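    def align_figures():
        """Show every open matplotlib figure left to right in one output cell."""
        import base64
        import matplotlib.pyplot as plt
        from matplotlib._pylab_helpers import Gcf
        from IPython.display import display_html
        try:
            # Newer installs ship the inline backend as matplotlib_inline.
            from matplotlib_inline.backend_inline import show
        except ImportError:
            # Older ipykernel versions, as in the original answer.
            from ipykernel.pylab.backend_inline import show
    
        images = []
        for figure_manager in Gcf.get_all_fig_managers():
            fig = figure_manager.canvas.figure
            # Render the figure to PNG through IPython's display formatter.
            png = get_ipython().display_formatter.format(fig)[0]['image/png']
            # Handle both raw bytes and an already base64-encoded string.
            src = png if isinstance(png, str) else base64.b64encode(png).decode()
            images.append('<img style="margin:0" align="left" src="data:image/png;base64,{}"/>'.format(src))
    
        html = "<div>{}</div>".format("".join(images))
        # Stop the inline backend from also drawing the figures on its own.
        show._draw_called = False
        plt.close('all')
        display_html(html, raw=True)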
    

    Here is a test:

    import matplotlib.pyplot as pl  # assumed alias; the original snippet does not show its import
    
    fig1, ax1 = pl.subplots(figsize=(4, 3))
    fig2, ax2 = pl.subplots(figsize=(4, 3))
    fig3, ax3 = pl.subplots(figsize=(4, 3))
    align_figures()
    

    This code assumes that the figure output format is PNG.
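
The core of align_figures() is the HTML it builds: each figure becomes an <img> tag with a base64 data URI and align="left", so the browser lays the images out in a row. The following is a minimal sketch of just that step, assuming the PNG data is already available as bytes. images_to_html_row is a hypothetical helper name, not part of IPython or matplotlib, and the sample bytes are stand-ins rather than real images.

```python
import base64


def images_to_html_row(png_images):
    """Build one HTML <div> that shows PNG images left to right.

    png_images is an iterable of raw PNG bytes (hypothetical input; in the
    answer above the image data comes from IPython's display formatter).
    """
    tags = []
    for png in png_images:
        # b64encode emits no line breaks, so the data URI stays on one line.
        src = base64.b64encode(png).decode("ascii")
        tags.append(
            '<img style="margin:0" align="left" '
            'src="data:image/png;base64,{}"/>'.format(src)
        )
    return "<div>{}</div>".format("".join(tags))


# Stand-in bytes only; real input would be complete PNG files.
html = images_to_html_row([b"\x89PNG-one", b"\x89PNG-two"])
```

In a notebook, the resulting string could then be passed to display_html(html, raw=True), as the answer does.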

  • 2021-01-26 20:12

    First, let me recommend using a colormap other than jet, for the reasons detailed in "A better colormap for matplotlib".

    As for what you want to do, you can achieve it with a modified version of the code from https://stackoverflow.com/a/26432947/835607

    I've extended that function to handle the z-axis of 3D plots as well as the colorbars you are using.

    %matplotlib inline
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.transforms import Bbox
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib.ticker import LinearLocator, FormatStrFormatter
    
    def full_extent(ax, xpad=0.0, ypad=0.0, cbar=None):
        """Modified from https://stackoverflow.com/a/26432947/835607
    
        Get the full extent of an axes, including axes labels, tick labels, and
        titles.
        You may need to pad the x or y dimension in order to not get slightly chopped off labels
    
        For text objects, we need to draw the figure first, otherwise the extents
        are undefined. These draws can be eliminated by calling plt.show() prior 
        to calling this function."""
    
        ax.figure.canvas.draw()
        items = ax.get_xticklabels() + ax.get_yticklabels() 
        items += [ax, ax.title, ax.xaxis.label, ax.yaxis.label]
        if '3D' in str(type(ax)):
            items += ax.get_zticklabels() + [ax.zaxis.label]
        if cbar:
            items += cbar.ax.get_yticklabels()
            bbox = Bbox.union([cbar.ax.get_window_extent()] + [item.get_window_extent() for item in items])
        else:
            bbox = Bbox.union([item.get_window_extent() for item in items])
        return bbox.expanded(1.0 + xpad, 1.0 + ypad)
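
To make the bounding-box arithmetic in full_extent() concrete, here is a small sketch of the three operations it relies on: Bbox.union merges extents, expanded() pads the result about its center, and the figure's inverted dpi_scale_trans converts pixels to inches for savefig. The pixel values and the dpi of 100 are made-up numbers chosen only for illustration.

```python
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox

# Two made-up extents in display (pixel) coordinates: x0, y0, x1, y1.
axes_box = Bbox.from_extents(100, 50, 400, 250)
label_box = Bbox.from_extents(60, 120, 90, 180)

# The union is the smallest box that contains both inputs.
merged = Bbox.union([axes_box, label_box])

# expanded() scales width and height about the center of the box,
# which is how the xpad and ypad arguments add a margin.
padded = merged.expanded(1.0 + 0.05, 1.0 + 0.10)

# savefig(bbox_inches=...) expects inches, so divide out the dpi.
fig = Figure(dpi=100)
inches = padded.transformed(fig.dpi_scale_trans.inverted())

print(merged.extents, padded.width, padded.height, inches.width, inches.height)
```

Because the padding is relative to the box size, a wide axes gains more horizontal margin than a narrow one for the same xpad.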
    

    As an example, I plot three subplots and save each one to a separate file. Note that the full_extent function takes xpad, ypad, and cbar as arguments. For plots that have a colorbar, make sure to pass the colorbar object to the function. You may also need to experiment with the padding to get the best results.

    # Make an example plot with 3 subplots...
    fig = plt.figure(figsize=(9,4))
    
    #3D Plot
    ax1 = fig.add_subplot(1,3,1,projection='3d')
    X = np.arange(-5, 5, 0.25)
    Y = np.arange(-5, 5, 0.25)
    X, Y = np.meshgrid(X, Y)
    R = np.sqrt(X**2 + Y**2)
    Z = np.sin(R)
    surf = ax1.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis',
                           linewidth=0, antialiased=False)
    ax1.set_zlim(-1.01, 1.01)
    ax1.zaxis.set_major_locator(LinearLocator(10))
    ax1.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
    
    # This plot has a colorbar that we'll need to pass to extent
    ax2 = fig.add_subplot(1,3,2)
    data = np.clip(np.random.randn(250, 250), -1, 1)
    cax = ax2.imshow(data, interpolation='nearest', cmap='viridis')
    ax2.set_title('Gaussian noise')
    cbar = fig.colorbar(cax)
    ax2.set_xlabel('asdf')
    ax2.set_ylabel('Some Cool Data')
    
    
    #3rd plot for fun
    ax3 = fig.add_subplot(1,3,3)
    ax3.plot([1,4,5,7,7],[3,5,7,8,3],'ko--')
    ax3.set_ylabel('adsf')
    ax3.set_title('a title')
    
    
    plt.tight_layout()  # avoid overlapping labels
    plt.show()  # show in the notebook; this also gives the text an extent
    fig.savefig('full_figure.png')  # the whole figure, just in case
    
    # Save each axes together with its tick labels, axis labels, and title
    extent1 = full_extent(ax1).transformed(fig.dpi_scale_trans.inverted())
    fig.savefig('ax1_figure.png', bbox_inches=extent1)
    
    extent2 = full_extent(ax2,.05,.1,cbar).transformed(fig.dpi_scale_trans.inverted())
    fig.savefig('ax2_figure.png', bbox_inches=extent2)
    
    extent3 = full_extent(ax3).transformed(fig.dpi_scale_trans.inverted())
    fig.savefig('ax3_figure.png', bbox_inches=extent3)
    

    This draws the three plots on one row, as you wanted, and creates a cropped output image for each one (ax1_figure.png, ax2_figure.png, ax3_figure.png).
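
As a quick check of the cropping idea, the sketch below renders a figure twice, once in full and once cropped to a single axes, and compares the pixel sizes. It is not the answer's code: it swaps in matplotlib's built-in Axes.get_tightbbox in place of full_extent, uses the headless Agg backend so it runs outside a notebook, and the helper names png_size and save_png are hypothetical.

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt


def png_size(png_bytes):
    """Return (width, height) in pixels, read from the PNG IHDR chunk."""
    return struct.unpack(">II", png_bytes[16:24])


def save_png(fig, bbox_inches=None):
    """Render fig to PNG bytes, optionally cropped to bbox_inches (inches)."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches=bbox_inches)
    return buf.getvalue()


fig, (left, right) = plt.subplots(1, 2, figsize=(8, 3), dpi=100)
left.plot([0, 1, 2], [0, 1, 4])
right.plot([0, 1, 2], [4, 1, 0])
fig.canvas.draw()  # text and tick extents are only defined after a draw

# Tight bounding box of the left axes in pixels, converted to inches.
extent = left.get_tightbbox(fig.canvas.get_renderer())
extent_in = extent.transformed(fig.dpi_scale_trans.inverted())

full_w, full_h = png_size(save_png(fig))
crop_w, crop_h = png_size(save_png(fig, bbox_inches=extent_in))
plt.close(fig)

print("full:", (full_w, full_h), "cropped:", (crop_w, crop_h))
```

The cropped image should be noticeably narrower than the full figure, since it covers only one of the two axes.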
