matplotlib savefig performance, saving multiple pngs within loop

Posted by 前提是你 on 2019-11-26 10:02:49

Question


I'm hoping to find a way to optimise the following situation. I have a large contour plot created with matplotlib's imshow. I then want to use this plot to create a large number of PNG images, where each image is a small section of the full image, obtained by changing the x and y limits and the aspect ratio.

So no plot data is changing in the loop, only the axis limits and the aspect ratio are changing between each png image.

The following MWE creates 70 PNG images in a "figs" folder, demonstrating the simplified idea. About 80% of the runtime is taken up by fig.savefig('figs/'+filename).

I've looked into the following without finding an improvement:

  • An alternative to matplotlib with a focus on speed -- I've struggled to find any examples/documentation of contour/surface plots with similar requirements
  • Multiprocessing -- Similar questions I've seen here appear to require fig = plt.figure() and ax.imshow to be called within the loop, since fig and ax can't be pickled. In my case this would cost more than any speed gained from multiprocessing.

I'd appreciate any insight or suggestions you might have.

import numpy as np
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import time, os

def make_plot(x, y, fig, ax):
    # only the view changes between images: aspect ratio and axis limits
    aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = '{:d}_{:d}.png'.format(x,y)

    ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig('figs/'+filename)

if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)

fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
# in the real case, imshow is an expensive calculation which can't be put inside the loop
ax.imshow(data, interpolation='nearest')

# time.clock() was removed in Python 3.8; perf_counter() is its replacement
tstart = time.perf_counter()
for i in range(1, 8):
    for j in range(3, 13):
        make_plot(i, j, fig, ax)

print('took {:.2f} seconds'.format(time.perf_counter()-tstart))

Answer 1:


Since the bottleneck in this case is the call to plt.savefig(), there is not much room for optimization. Internally, the figure is rendered from scratch on every save, and that takes a while. Reducing the number of vertices to be drawn might reduce the time a bit.
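To check how much of the loop is really spent in savefig on a given machine, a timing sketch along the following lines can help. It is only an illustration: the function name savefig_fraction, the frame count and the way the limits are drawn are assumptions and do not come from the question. It saves to an in-memory buffer so that no files are written.

```python
import io
import time

import matplotlib
matplotlib.use("agg")  # non-interactive backend, as in the question
import matplotlib.pyplot as plt
import numpy as np


def savefig_fraction(n_frames=10, seed=0):
    """Return (share of loop time spent in savefig, number of frames saved).

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    fig = plt.figure()
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.imshow(rng.random((25, 25)), interpolation="nearest")

    t_save = 0.0
    saved = 0
    t_loop = time.perf_counter()
    for _ in range(n_frames):
        lo, hi = np.sort(rng.random(2) * 24)
        ax.set_xlim(lo, hi + 0.5)  # +0.5 keeps the limits from collapsing
        ax.set_ylim(lo, hi + 0.5)
        buf = io.BytesIO()         # in-memory target instead of a file
        t0 = time.perf_counter()
        fig.savefig(buf, format="png")
        t_save += time.perf_counter() - t0
        saved += 1
    total = time.perf_counter() - t_loop
    plt.close(fig)
    return t_save / total, saved


fraction, saved = savefig_fraction()
print("savefig share of loop time: {:.0%} over {} frames".format(fraction, saved))
```

If the reported share is close to 1, changing the limit-setting code will not help much, and the remaining options are cheaper rendering or running the saves in parallel.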

Your code takes 2.5 seconds to run on my machine (Win 8, 4-core i5 at 3.5 GHz), which does not seem too bad. Some improvement is possible with multiprocessing.

A note about multiprocessing: it may seem surprising that the pyplot state machine works inside multiprocessing at all, but it does. And in this case, since every image is based on the same figure and axes objects, one does not even have to create new figures and axes.

I modified an answer I gave a while ago for your case; the total time is roughly halved using multiprocessing with 5 processes on 4 cores. I appended a bar plot which shows the effect of multiprocessing.

import numpy as np
#import matplotlib as mpl
#mpl.use('agg') # use of agg seems to slow things down a bit
import matplotlib.pyplot as plt
import multiprocessing
import time, os

def make_plot(d):
    # time.time() is a wall clock, so timestamps taken in different
    # worker processes can be compared in the bar plot below
    start = time.time()
    x, y = d
    #using aspect in this way causes a warning for me
    #aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = '{:d}_{:d}.png'.format(x, y)
    # each worker uses the current axes of the figure created at module level
    ax = plt.gca()
    #ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.savefig('figs/'+filename)
    stop = time.time()
    return np.array([x, y, start, stop])

if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)

fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
ax.imshow(data, interpolation='nearest')


# one (x, y) pair per image
some_list = []
for i in range(1, 8):
    for j in range(3, 13):
        some_list.append((i, j))


if __name__ == "__main__":
    multiprocessing.freeze_support()
    tstart = time.time()
    num_proc = 5
    p = multiprocessing.Pool(num_proc)

    nu = p.map(make_plot, some_list)
    p.close()
    p.join()

    tooktime = 'Plotting of {} frames took {:.2f} seconds'
    tooktime = tooktime.format(len(some_list), time.time()-tstart)
    print(tooktime)
    nu = np.array(nu)

    # bar plot: one horizontal bar per image, from its start to its stop time
    plt.close("all")
    fig, ax = plt.subplots(figsize=(8,5))
    plt.suptitle(tooktime)
    ax.barh(np.arange(len(some_list)), nu[:,3]-nu[:,2],
            height=np.ones(len(some_list)), left=nu[:,2]-tstart, align="center")
    ax.set_xlabel("time [s]")
    ax.set_ylabel("image number")
    ax.set_ylim([-1,70])
    plt.tight_layout()
    plt.savefig(__file__+".png")
    plt.show()
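
The (x, y, start, stop) rows returned by the workers can also be summarized numerically rather than only plotted. The following is a minimal sketch; the function name summarize_timings and the synthetic rows are made up for illustration, and it assumes the start and stop values from all workers come from a clock that is comparable across processes.

```python
import numpy as np


def summarize_timings(rows):
    """Summarize rows of (x, y, start, stop), one row per saved frame.

    Hypothetical helper for illustration only.
    """
    rows = np.asarray(rows, dtype=float)
    durations = rows[:, 3] - rows[:, 2]          # time spent on each frame
    wall = rows[:, 3].max() - rows[:, 2].min()   # first start to last stop
    busy = durations.sum()                       # work summed over all workers
    return {
        "frames": len(rows),
        "mean_frame_s": durations.mean(),
        "wall_s": wall,
        "busy_s": busy,
        "overlap": busy / wall,  # about 1 for a serial run, above 1 when frames overlap
    }


# Synthetic example: four frames of 1 s each, processed two at a time.
example = [(1, 3, 0.0, 1.0), (1, 4, 0.0, 1.0), (1, 5, 1.0, 2.0), (1, 6, 1.0, 2.0)]
stats = summarize_timings(example)
print(stats)
```

An overlap value well below the number of processes suggests that the workers spend part of their time waiting, for example on startup or disk access, rather than rendering.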



Source: https://stackoverflow.com/questions/41037840/matplotlib-savefig-performance-saving-multiple-pngs-within-loop
