matplotlib iterate subplot axis array through single list

醉话见心 2020-12-13 04:08

Is there a simple/clean way to iterate over the array of axes returned by subplots, like

nrow = ncol = 2
a = []
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
for         


        
6 answers
  • 2020-12-13 04:35

    I am not sure when it was added, but there is now a squeeze keyword argument. Passing squeeze=False makes sure the result is always a 2D NumPy array, and turning that into a 1D array is easy:

    import matplotlib.pyplot as plt

    fig, ax2d = plt.subplots(2, 2, squeeze=False)
    axli = ax2d.flatten()
    

    This works for any number of subplots and needs no special case for a single Axes, so it is a little simpler than the accepted answer (perhaps squeeze didn't exist back then).
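    A minimal sketch of this approach, assuming the non-interactive Agg backend so it runs without a display. It loops over a 1x1 and a 2x3 grid to check that squeeze=False returns a 2D array in both cases, so the same flatten() loop needs no special case:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

shapes = {}
counts = {}
for nrows, ncols in [(1, 1), (2, 3)]:
    fig, ax2d = plt.subplots(nrows, ncols, squeeze=False)
    shapes[(nrows, ncols)] = ax2d.shape  # always (nrows, ncols), even for 1x1
    axli = ax2d.flatten()                # 1D copy of the Axes array
    for i, ax in enumerate(axli):
        ax.set_title(f"Axes {i}")
    counts[(nrows, ncols)] = len(axli)
    plt.close(fig)

print(shapes)  # {(1, 1): (1, 1), (2, 3): (2, 3)}
```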

  • 2020-12-13 04:39

    The fig returned by plt.subplots has an axes attribute holding a list of all of its Axes. To iterate over all the subplots in a figure, you can use:

    import matplotlib.pyplot as plt

    nrow = 2
    ncol = 2
    fig, axs = plt.subplots(nrow, ncol)
    for i, ax in enumerate(fig.axes):
        ax.set_ylabel(str(i))
    

    This also works for nrow == ncol == 1.
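    A small sketch, again assuming the headless Agg backend, that checks two things about fig.axes: that it lists the Axes created by plt.subplots in the same row-major order as the returned array, and that it is still a list when plt.subplots returns a single Axes. Keep in mind that fig.axes contains every Axes on the figure, so extra ones such as a colorbar's Axes show up in it too.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)
for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))

# Same Axes objects, same row-major order as the array
same_order = fig.axes == list(axs.flat)
labels = [ax.get_ylabel() for ax in axs.flat]
print(same_order, labels)  # True ['0', '1', '2', '3']
plt.close(fig)

fig1, ax1 = plt.subplots(1, 1)  # ax1 is a bare Axes, not an array
single = fig1.axes              # but fig1.axes is still a one-element list
plt.close(fig1)
```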

  • 2020-12-13 04:47

    The axes returned by plt.subplots is a NumPy array, so you can iterate over it through its flat attribute.

    Why don't you try the following code?

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3)
    for ax in axes.flat:
        ## do something with instance of 'ax'
        pass
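    A sketch of a typical use, pairing each Axes with its own data via zip (the datasets list is made-up placeholder data, and the Agg backend is assumed so it runs headless). Note that axes.flat only exists when plt.subplots returns an array; with a single subplot you get a bare Axes instead.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Hypothetical placeholder data: one short series per subplot
datasets = [[i, i + 1, i + 2] for i in range(6)]

fig, axes = plt.subplots(2, 3)
for ax, data in zip(axes.flat, datasets):
    ax.plot(data)                     # one line per Axes
    ax.set_title(f"max={max(data)}")

lines_per_ax = [len(ax.get_lines()) for ax in axes.flat]
titles = [ax.get_title() for ax in axes.flat]
plt.close(fig)
```

    If there are fewer datasets than Axes, zip stops at the shorter sequence and the remaining Axes stay empty.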
    
  • 2020-12-13 04:48

    TL;DR: axes.flat is the most Pythonic way to iterate through the axes.

    As others have pointed out, the axes returned by plt.subplots() is a NumPy array of Axes objects (unless there is only one subplot), so there are plenty of built-in NumPy methods for flattening it. Of those options, axes.flat is the least verbose. Furthermore, axes.flatten() returns a copy of the array, whereas axes.flat returns an iterator over the array, so axes.flat avoids making that copy; with only a handful of Axes, though, the difference is negligible.

    Stealing @Sukjun-Kim's example:

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3)
    for ax in axes.flat:
        ## do something with instance of 'ax'
        pass
    

    Sources: axes.flat docs, Matplotlib tutorial
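    A sketch, assuming the Agg backend, that makes the copy-versus-iterator distinction visible with np.shares_memory. ndarray.ravel(), which the answer doesn't mention, is included for comparison: it returns a flattened view instead of a copy when it can.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3)

it = axes.flat            # numpy.flatiter: iterates in place, no new array
copied = axes.flatten()   # always a new 1D array (a copy)
view = axes.ravel()       # 1D view when possible (no copy for this array)

print(type(it).__name__)               # flatiter
print(np.shares_memory(axes, copied))  # False
print(np.shares_memory(axes, view))    # True

# All three yield the same Axes objects in the same row-major order
same = list(it) == list(copied) == list(view)
plt.close(fig)
```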

  • 2020-12-13 04:54

    The axs return value is a NumPy array, which can, I believe, be reshaped without copying the data. If you use the following, you'll get a flat array that you can iterate over cleanly.

    import matplotlib.pyplot as plt

    nrow, ncol = 1, 2
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)

    for i, ax in enumerate(axs.reshape(-1)):
        ax.set_ylabel(str(i))
    

    This doesn't work when nrows and ncols are both 1, since the return value is then a single Axes rather than an array; you could turn it into a one-element array for consistency, though it feels a bit like a kludge:

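    import matplotlib.pyplot as plt
    import numpy as np

    nrow, ncol = 1, 1
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
    axs = np.array(axs)  # wrap the single Axes in a 0-d array

    for i, ax in enumerate(axs.reshape(-1)):
        ax.set_ylabel(str(i))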
    

    See the reshape docs: an argument of -1 tells reshape to infer the size of that dimension from the total number of elements.
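    A sketch that folds both cases into one hypothetical helper, iter_axes (not a Matplotlib function), assuming the Agg backend. It uses np.asarray rather than np.array so an existing Axes array is not copied:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def iter_axes(axs):
    """Hypothetical helper: return a 1D array of Axes whether plt.subplots
    gave back a single Axes or an array of them."""
    # np.asarray wraps a lone Axes in a 0-d object array; reshape(-1)
    # then turns either case into a flat 1D array.
    return np.asarray(axs).reshape(-1)

results = {}
for nrow, ncol in [(1, 1), (1, 2), (2, 2)]:
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
    flat = iter_axes(axs)
    for i, ax in enumerate(flat):
        ax.set_ylabel(str(i))
    results[(nrow, ncol)] = [ax.get_ylabel() for ax in flat]
    plt.close(fig)
```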

  • 2020-12-13 04:54

    Here is a good practice.
    For example, if we need a 4-by-4 grid of subplots, we can set them up like this:

    import matplotlib.pyplot as plt

    rows, cols = 4, 4
    # squeeze=False guarantees a 2D array of Axes, even for a 1x1 grid
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(20, 16), squeeze=False, sharex=True, sharey=True)

    for i, ax in enumerate(axes.reshape(-1)):
        ax.set_ylabel(f'Subplot: {i}')
    

    The output is beautiful and clear.
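    When the row and column of each subplot matter, np.ndenumerate (a NumPy function swapped in here, not taken from the answer) yields each Axes together with its (row, col) index. Because squeeze=False already returns a 2D array, no np.array call is needed. A sketch assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rows, cols = 4, 4
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(20, 16),
                         squeeze=False, sharex=True, sharey=True)

# np.ndenumerate yields ((row, col), ax) pairs in row-major order
for (r, c), ax in np.ndenumerate(axes):
    i = r * cols + c  # same flat index that enumerate(axes.flat) would give
    ax.set_ylabel(f"Subplot: {i} ({r}, {c})")

labels = [ax.get_ylabel() for ax in axes.flat]
plt.close(fig)
```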
