More idiomatic way to display images in a grid with numpy

青春惊慌失措 2020-12-31 09:40

Is there a more idiomatic way to display a grid of images as in the below example?

import numpy as np

def gallery(array, ncols=3):
    nrows = np.math.ceil(         


        
3 Answers
  • 2020-12-31 09:59

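    This answer is based on @unutbu's, but it handles HWC-ordered tensors: the C channels of a single H × W × C array are laid out as an nrows × ncols grid of H × W tiles. When there are fewer channels than grid cells (c < nrows * ncols), the leftover cells are filled with black (all-zero) tiles.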

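    import numpy as np

    def tile(arr, nrows, ncols):
        """
        Args:
            arr: HWC format array
            nrows: number of tiled rows
            ncols: number of tiled columns
        Returns:
            2-D array of shape (nrows * h, ncols * w)
        """
        h, w, c = arr.shape
        out_height = nrows * h
        out_width = ncols * w
        # HWC -> CHW, so each channel becomes one h x w tile
        chw = np.moveaxis(arr, (0, 1, 2), (1, 2, 0))

        if c < nrows * ncols:
            # flatten into a fresh copy and zero-pad it in place up to
            # nrows * ncols tiles; the extra tiles render as black
            chw = chw.reshape(-1).copy()
            chw.resize(nrows * ncols * h * w)

        # (nrows, ncols, h, w) -> (nrows, h, ncols, w), then merge axes
        # so that pixel rows of neighbouring tiles sit side by side
        return (chw
            .reshape(nrows, ncols, h, w)
            .swapaxes(1, 2)
            .reshape(out_height, out_width))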
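    A quick check of the padding behaviour, as a sketch on made-up data: a 2 × 2 image with 3 channels tiled into a 2 × 2 grid has one more cell than channels, so the bottom-right tile should come out all zeros.

```python
import numpy as np

def tile(arr, nrows, ncols):
    # same logic as tile() above, repeated so this snippet runs on its own
    h, w, c = arr.shape
    chw = np.moveaxis(arr, (0, 1, 2), (1, 2, 0))
    if c < nrows * ncols:
        chw = chw.reshape(-1).copy()
        chw.resize(nrows * ncols * h * w)  # zero-pads in place
    return (chw
            .reshape(nrows, ncols, h, w)
            .swapaxes(1, 2)
            .reshape(nrows * h, ncols * w))

# toy HWC input (illustrative): channel k is filled with the value k + 1
arr = np.stack([np.full((2, 2), k + 1) for k in range(3)], axis=-1)
out = tile(arr, nrows=2, ncols=2)
print(out)
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 0 0]
#  [3 3 0 0]]
```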
    

    Here's a corresponding detiling function for the reverse direction:

    def detile(arr, nrows, ncols, c, h, w):
        """
        Args:
            arr: tiled array
            nrows: number of tiled rows
            ncols: number of tiled columns
            c: channels (number of tiles to keep)
            h: height of tile
            w: width of tile
        Returns:
            HWC array of shape (h, w, c)
        """
        # undo tile(): (nrows*h, ncols*w) -> (nrows, ncols, h, w),
        # then keep only the first c tiles, dropping any zero padding
        chw = (arr
            .reshape(nrows, h, ncols, w)
            .swapaxes(1, 2)
            .reshape(-1)[:c*h*w]
            .reshape(c, h, w))

        # CHW -> HWC
        return np.moveaxis(chw, (0, 1, 2), (2, 0, 1)).reshape(h, w, c)
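    A minimal round-trip check, assuming tile() and detile() behave exactly as written above (both are copied here so the snippet is self-contained). Tiling 5 random channels into a 2 × 3 grid adds one zero tile, and detiling with c=5 should drop it and give back the original array.

```python
import numpy as np

def tile(arr, nrows, ncols):
    h, w, c = arr.shape
    chw = np.moveaxis(arr, (0, 1, 2), (1, 2, 0))
    if c < nrows * ncols:
        chw = chw.reshape(-1).copy()
        chw.resize(nrows * ncols * h * w)  # zero-pads in place
    return (chw.reshape(nrows, ncols, h, w)
            .swapaxes(1, 2)
            .reshape(nrows * h, ncols * w))

def detile(arr, nrows, ncols, c, h, w):
    chw = (arr.reshape(nrows, h, ncols, w)
           .swapaxes(1, 2)
           .reshape(-1)[:c * h * w]
           .reshape(c, h, w))
    return np.moveaxis(chw, (0, 1, 2), (2, 0, 1))

rng = np.random.default_rng(0)
x = rng.random((4, 6, 5))          # H=4, W=6, C=5 (arbitrary sizes)
tiled = tile(x, nrows=2, ncols=3)  # 6 grid cells for 5 channels
restored = detile(tiled, 2, 3, c=5, h=4, w=6)

print(tiled.shape)                  # (8, 18)
print(np.array_equal(restored, x))  # True
```

    The sixth grid cell (row 1, column 2) is the zero padding tile, i.e. tiled[4:, 12:].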
    
  • 2020-12-31 10:01
    import numpy as np
    import matplotlib.pyplot as plt
    
    def gallery(array, ncols=3):
        nindex, height, width, intensity = array.shape
        nrows = nindex//ncols
        assert nindex == nrows*ncols
        # want result.shape = (height*nrows, width*ncols, intensity)
        result = (array.reshape(nrows, ncols, height, width, intensity)
                  .swapaxes(1,2)
                  .reshape(height*nrows, width*ncols, intensity))
        return result
    
    def make_array():
        from PIL import Image
        return np.array([np.asarray(Image.open('face.png').convert('RGB'))]*12)
    
    array = make_array()
    result = gallery(array)
    plt.imshow(result)
    plt.show()
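    Since face.png is not included, here is a sketch that exercises the same gallery() on synthetic data instead (the solid-valued "images" are made up for illustration) and checks where each input image lands in the output, without PIL or matplotlib.

```python
import numpy as np

def gallery(array, ncols=3):
    nindex, height, width, intensity = array.shape
    nrows = nindex // ncols
    assert nindex == nrows * ncols
    return (array.reshape(nrows, ncols, height, width, intensity)
            .swapaxes(1, 2)
            .reshape(height * nrows, width * ncols, intensity))

# 12 tiny 2x3 RGB "images"; every pixel of image k has the value k
values = np.arange(12, dtype=np.uint8)
images = values[:, None, None, None] * np.ones((12, 2, 3, 3), dtype=np.uint8)
result = gallery(images, ncols=3)

print(result.shape)  # (8, 9, 3): 4 rows x 3 columns of 2x3 tiles
# sampling the top-left pixel of every tile gives 0..11 in row-major order,
# i.e. image k sits in grid row k // 3, grid column k % 3
print(result[::2, ::3, 0])
```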
    

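    which displays the 12 copies of the image as a grid of 4 rows and 3 columns (output figure not reproduced).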


    We have an array of shape (nrows*ncols, height, width, intensity). We want an array of shape (height*nrows, width*ncols, intensity).

    So the idea here is to first use reshape to split apart the first axis into two axes, one of length nrows and one of length ncols:

    array.reshape(nrows, ncols, height, width, intensity)
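    To see this step concretely, a small sketch with arbitrary sizes (nrows=2, ncols=3, 4 × 5 RGB images): after the reshape, image k of the original stack is found at index [k // ncols, k % ncols], and no data is copied.

```python
import numpy as np

nrows, ncols, height, width, intensity = 2, 3, 4, 5, 3
array = np.random.default_rng(0).random((nrows * ncols, height, width, intensity))

split = array.reshape(nrows, ncols, height, width, intensity)
print(split.shape)                     # (2, 3, 4, 5, 3)
print(np.shares_memory(split, array))  # True: a view, not a copy

# the first axis was only split in two, so image k keeps its position
for k in range(nrows * ncols):
    assert np.array_equal(split[k // ncols, k % ncols], array[k])
```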
    

    This allows us to use swapaxes(1,2) to reorder the axes so that the shape becomes (nrows, height, ncols, width, intensity). Notice that this places nrows next to height and ncols next to width.
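    Continuing with the same arbitrary sizes, a short sketch of the swapaxes step: it only permutes strides, so the result is a non-contiguous view, and pixel row r of the image in grid cell (i, j) becomes swapped[i, r, j].

```python
import numpy as np

nrows, ncols, height, width, intensity = 2, 3, 4, 5, 3
array = np.random.default_rng(0).random((nrows * ncols, height, width, intensity))

swapped = array.reshape(nrows, ncols, height, width, intensity).swapaxes(1, 2)
print(swapped.shape)                  # (2, 4, 3, 5, 3)
print(swapped.flags['C_CONTIGUOUS'])  # False: only the strides changed

# pixel row r of the image in grid cell (i, j) is now swapped[i, r, j]
i, j, r = 1, 2, 3
assert np.array_equal(swapped[i, r, j], array[i * ncols + j, r])
```

    Because the view is not contiguous, the final reshape has to copy the data into the new raveled order.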

    Since reshape does not change the raveled order of the data, reshape(height*nrows, width*ncols, intensity) now produces the desired array.
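    As a sanity check on that claim, a sketch comparing the reshape/swapaxes result with an explicit assembly built from np.concatenate (a different, loop-style technique used here only for comparison; the sizes are arbitrary).

```python
import numpy as np

nrows, ncols, height, width, intensity = 2, 3, 4, 5, 3
array = np.random.default_rng(1).random((nrows * ncols, height, width, intensity))

fast = (array.reshape(nrows, ncols, height, width, intensity)
        .swapaxes(1, 2)
        .reshape(height * nrows, width * ncols, intensity))

# explicit version: glue each grid row together along the width axis,
# then stack the grid rows along the height axis
rows = [np.concatenate([array[i * ncols + j] for j in range(ncols)], axis=1)
        for i in range(nrows)]
slow = np.concatenate(rows, axis=0)

print(np.array_equal(fast, slow))  # True
```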

    This is (in spirit) the same as the idea used in the unblockshaped function.

  • 2020-12-31 10:10

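    Another way is to use scikit-image's view_as_blocks, which avoids swapping axes by hand. This version also pads a short last row with black images when the number of images is not a multiple of ncols: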

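    from skimage.util import view_as_blocks
    import numpy as np

    def refactor(im_in, ncols=3):
        n, h, w, c = im_in.shape
        dn = (-n) % ncols  # number of blank images needed to fill the last row
        # output canvas of shape (nrows * h, ncols * w, c), filled block by block
        im_out = (np.empty((n + dn) * h * w * c, im_in.dtype)
                  .reshape(-1, w * ncols, c))
        # block view of shape (nrows, ncols, 1, h, w, c); writing to a block
        # writes straight into im_out
        view = view_as_blocks(im_out, (h, w, c))
        # place each image; the dn padding slots are set to 0 (black)
        for k, im in enumerate(list(im_in) + dn * [0]):
            view[k // ncols, k % ncols, 0] = im
        return im_out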
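    If scikit-image isn't available, the same fill-a-block-view idea can be sketched with plain NumPy: reshaping the contiguous output canvas to (nrows, h, ncols, w, c) and swapping axes 1 and 2 gives a writable view indexed by grid cell. This is a swapped-in technique, not view_as_blocks itself, and the name refactor_np is hypothetical.

```python
import numpy as np

def refactor_np(im_in, ncols=3):
    """Hypothetical NumPy-only variant of refactor(); pads with black images."""
    n, h, w, c = im_in.shape
    nrows = -(-n // ncols)  # ceil(n / ncols)
    im_out = np.zeros((nrows * h, ncols * w, c), im_in.dtype)
    # view of im_out with shape (nrows, ncols, h, w, c); writes go into im_out
    blocks = im_out.reshape(nrows, h, ncols, w, c).swapaxes(1, 2)
    for k, im in enumerate(im_in):
        blocks[k // ncols, k % ncols] = im
    return im_out

# 4 illustrative 2x2 RGB images filled with 1, 2, 3, 4
imgs = np.ones((4, 2, 2, 3), dtype=np.uint8) * np.arange(1, 5, dtype=np.uint8)[:, None, None, None]
out = refactor_np(imgs, ncols=3)
print(out.shape)  # (4, 6, 3): 2 rows x 3 columns of 2x2 tiles
```

    With 4 images and ncols=3, the last two cells of the second row stay zero, matching what refactor() produces for its padding slots.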
    