Find a 3x3 sliding window over an image

星月不相逢 2020-12-09 13:56

I have an image.

I want to obtain a 3x3 window (neighbouring pixels) for every pixel in the image.

I have this Python code:

for x in range(2,         


        
3 Answers
  • 2020-12-09 14:09

    I think the following does what you are after. The only Python loop is over the 9 positions of the window. There is probably a way to vectorize that too, but it's likely not worth the effort.

    import numpy
    
    im = numpy.random.randint(0,50,(5,7))
    
    # idx_2d contains the indices of each position in the array
    idx_2d = numpy.mgrid[0:im.shape[0],0:im.shape[1]]
    
    # We break that into 2 sub arrays
    x_idx = idx_2d[1]
    y_idx = idx_2d[0]
    
    # The mask is used to ignore the edge values (or indeed any values).
    mask = numpy.ones(im.shape, dtype='bool')
    mask[0, :] = False
    mask[:, 0] = False
    mask[im.shape[0] - 1, :] = False
    mask[:, im.shape[1] - 1] = False
    
    # We create and fill an array that contains the lookup for every
    # possible 3x3 array.
    idx_array = numpy.zeros((im[mask].size, 3, 3), dtype='int64')
    
    # For each of the 9 offsets in the 3x3 window, compute the flattened
    # (row-major) index of the neighbouring pixel: column + row * width.
    for n in range(0, 3):        # column offset
        for m in range(0, 3):    # row offset
            idx = (x_idx + (n-1)) + (y_idx  + (m-1)) * im.shape[1]
    
            # Keep only the interior pixels and write them to the big array
            idx_array[:, m, n] = idx[mask]
    
    
    # sub_images contains every valid 3x3 sub image
    sub_images = im.ravel()[idx_array]
    
    # Finally, flatten each sub-image to 9 values and sort it
    sorted_sub_images = numpy.sort(sub_images.reshape((-1, 9)))
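The answer suggests the remaining loop could be vectorized. One way, assuming NumPy 1.20 or later, is `numpy.lib.stride_tricks.sliding_window_view`, which returns every full 3x3 window as a read-only view. The sketch below re-implements the index-lookup approach as a hypothetical helper `windows_by_index` and checks that both give the same windows in the same order:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def windows_by_index(im):
    """Interior 3x3 windows via flattened-index lookup (the approach above)."""
    rows, cols = im.shape
    y_idx, x_idx = np.mgrid[0:rows, 0:cols]
    mask = np.zeros(im.shape, dtype=bool)
    mask[1:-1, 1:-1] = True          # keep only pixels with a full neighbourhood
    idx_array = np.empty((mask.sum(), 3, 3), dtype=np.intp)
    for m in range(3):               # row offset within the window
        for n in range(3):           # column offset within the window
            idx = (x_idx + n - 1) + (y_idx + m - 1) * cols
            idx_array[:, m, n] = idx[mask]
    return im.ravel()[idx_array]


im = np.random.randint(0, 50, (5, 7))
by_index = windows_by_index(im)
# sliding_window_view has shape (rows-2, cols-2, 3, 3); reshaping to
# (-1, 3, 3) lists the windows in the same row-major order as idx[mask].
by_view = sliding_window_view(im, (3, 3)).reshape(-1, 3, 3)
print(by_index.shape)                       # (15, 3, 3)
print(np.array_equal(by_index, by_view))    # True
```

The view-based version avoids building the index array at all; the `reshape` makes a copy only because the windows overlap in memory.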
    
  • 2020-12-09 14:27

    This can be done faster by reshaping and swapping axes, then repeating over all kernel offsets, like this:

    import numpy as np
    
    im = np.arange(81).reshape(9,9)
    print(np.swapaxes(im.reshape(3,3,3,-1),1,2))
    

    This gives you an array of 3x3 tiles that tessellates across the surface (shown here with the tiles laid out side by side):

    [[[[ 0  1  2]   [[ 3  4  5]   [[ 6  7  8]
       [ 9 10 11]    [12 13 14]    [15 16 17]
       [18 19 20]]   [21 22 23]]   [24 25 26]]]
    
     [[[27 28 29]   [[30 31 32]   [[33 34 35]
       [36 37 38]    [39 40 41]    [42 43 44]
       [45 46 47]]   [48 49 50]]   [51 52 53]]]
    
     [[[54 55 56]   [[57 58 59]   [[60 61 62]
       [63 64 65]    [66 67 68]    [69 70 71]
       [72 73 74]]   [75 76 77]]   [78 79 80]]]]
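The reshape works because a row-major (9, 9) array can be viewed as (tile row, row within tile, tile column, column within tile); swapping the middle two axes groups each tile's elements together. A small sketch generalizing this to k x k tiles; the function name `tiles` and the parameter `k` are illustrative, not from the answer, and both dimensions are assumed divisible by `k`:

```python
import numpy as np


def tiles(im, k=3):
    """Split a 2-D array into non-overlapping k x k tiles.

    Returns shape (rows // k, cols // k, k, k), where
    tiles[i, j] == im[k*i:k*i+k, k*j:k*j+k].
    """
    rows, cols = im.shape
    # (rows, cols) -> (tile_row, row_in_tile, tile_col, col_in_tile)
    t = im.reshape(rows // k, k, cols // k, k)
    # -> (tile_row, tile_col, row_in_tile, col_in_tile)
    return t.swapaxes(1, 2)


im = np.arange(81).reshape(9, 9)
t = tiles(im)
print(t.shape)   # (3, 3, 3, 3)
print(t[1, 2])   # rows 3-5, columns 6-8: the third tile of the second row above
```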
    

    To get the overlapping tiles, repeat this for the 8 other offsets, 'wrapping' the array each time with a combination of vstack and column_stack; the example below shifts by one row and one column. Note that the right-hand and bottom tiles wrap around to the opposite edge (which may or may not be what you want, depending on how you treat edge conditions):

    im = np.vstack((im[1:], im[0]))
    im = np.column_stack((im[:,1:], im[:,0]))
    print(np.swapaxes(im.reshape(3,3,3,-1),1,2))
    
    # Output:
    [[[[10 11 12]   [[13 14 15]   [[16 17  9]
       [19 20 21]    [22 23 24]    [25 26 18]
       [28 29 30]]   [31 32 33]]   [34 35 27]]]
    
     [[[37 38 39]   [[40 41 42]   [[43 44 36]
       [46 47 48]    [49 50 51]    [52 53 45]
       [55 56 57]]   [58 59 60]]   [61 62 54]]]
    
     [[[64 65 66]   [[67 68 69]   [[70 71 63]
       [73 74 75]    [76 77 78]    [79 80 72]
       [ 1  2  3]]   [ 4  5  6]]   [ 7  8  0]]]]
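The vstack/column_stack pair is a cyclic shift of the array by one row and one column, which `np.roll` expresses directly. A short sketch checking that the two agree for the shift used above:

```python
import numpy as np

im = np.arange(81).reshape(9, 9)

# Shift used in the answer: move the first row to the bottom, then the
# first column to the right-hand side.
shifted = np.vstack((im[1:], im[0]))
shifted = np.column_stack((shifted[:, 1:], shifted[:, 0]))

# np.roll with negative shifts moves elements toward lower indices,
# wrapping the leading row/column around to the end.
rolled = np.roll(im, shift=(-1, -1), axis=(0, 1))

print(np.array_equal(shifted, rolled))  # True
print(rolled[0, :3])                     # [10 11 12]
```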
    

    Doing it this way you end up with 9 sets of arrays, which then need to be interleaved back together. Generalising this, along with the reshaping, gives the following (for arrays whose dimensions are both divisible by 3):

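    import numpy as np
    
    def new(im):
        # final[r, c] is the 3x3 window whose top-left corner is (r, c),
        # wrapping around the bottom and right edges.
        rows,cols = im.shape
        final = np.zeros((rows, cols, 3, 3), dtype=im.dtype)
        for x in (0,1,2):
            for y in (0,1,2):
                # Cyclically shift the array up by x rows and left by y columns
                im1 = np.vstack((im[x:],im[:x]))
                im1 = np.column_stack((im1[:,y:],im1[:,:y]))
                # Cut into non-overlapping tiles and interleave them into final
                final[x::3,y::3] = np.swapaxes(im1.reshape(rows//3,3,cols//3,-1),1,2)
        return final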
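Note that `new` returns, at `final[r, c]`, the 3x3 window whose top-left corner (not centre) is `(r, c)`, wrapping past the bottom and right edges. Assuming NumPy 1.20 or later, the same array can be built without the divisible-by-3 restriction by wrap-padding two extra rows and columns and taking a sliding-window view. `new_wrapped` is a hypothetical name; `new` is repeated (Python 3 port) so the sketch runs on its own:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def new_wrapped(im):
    """Same result as new(im) for any shape: top-left-anchored, wrapped windows."""
    # Append the first two rows/columns after the last ones (cyclic padding).
    padded = np.pad(im, ((0, 2), (0, 2)), mode="wrap")
    # Shape (rows, cols, 3, 3); entry [r, c] is padded[r:r+3, c:c+3].
    return sliding_window_view(padded, (3, 3))


def new(im):
    # The answer's function, ported to Python 3 (integer division).
    rows, cols = im.shape
    final = np.zeros((rows, cols, 3, 3), dtype=im.dtype)
    for x in (0, 1, 2):
        for y in (0, 1, 2):
            im1 = np.vstack((im[x:], im[:x]))
            im1 = np.column_stack((im1[:, y:], im1[:, :y]))
            final[x::3, y::3] = np.swapaxes(im1.reshape(rows // 3, 3, cols // 3, -1), 1, 2)
    return final


im = np.arange(81).reshape(9, 9)
print(np.array_equal(new(im), new_wrapped(im)))  # True
```

The result of `new_wrapped` is a read-only view onto the padded array; call `.copy()` if you need to modify windows independently.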
    

    Comparing this new function with looping over all the slices (below) using timeit, it's about 4 times faster for a 300x300 array.

    def old(im):
        rows,cols = im.shape
        s = []
        # (x, y) is the window centre; skip the border so every slice is 3x3
        for x in range(1, rows - 1):
            for y in range(1, cols - 1):
                s.append(im[x-1:x+2, y-1:y+2])
        return s
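A rough harness for reproducing that comparison on your own machine, assuming Python 3. It uses `old` with interior-only loop bounds and a Python 3 port of `new`, and first checks that they agree: the centred window around pixel `(x, y)` from `old` is `new(im)[x - 1, y - 1]`, because `new` anchors each window at its top-left corner. Timings depend on the machine and NumPy version, so no numbers are claimed here; note also that `old` only creates slice views while `new` copies every window.

```python
import timeit

import numpy as np


def old(im):
    rows, cols = im.shape
    s = []
    for x in range(1, rows - 1):
        for y in range(1, cols - 1):
            s.append(im[x - 1:x + 2, y - 1:y + 2])
    return s


def new(im):
    rows, cols = im.shape
    final = np.zeros((rows, cols, 3, 3), dtype=im.dtype)
    for x in (0, 1, 2):
        for y in (0, 1, 2):
            im1 = np.vstack((im[x:], im[:x]))
            im1 = np.column_stack((im1[:, y:], im1[:, :y]))
            final[x::3, y::3] = np.swapaxes(im1.reshape(rows // 3, 3, cols // 3, -1), 1, 2)
    return final


# Correctness check on a small array (dimensions divisible by 3 for new()).
small = np.arange(9 * 12).reshape(9, 12)
windows = old(small)
final = new(small)
centres = [(x, y) for x in range(1, 8) for y in range(1, 11)]
print(all(np.array_equal(w, final[x - 1, y - 1]) for w, (x, y) in zip(windows, centres)))  # True

# Timing on a 300x300 array, as in the answer.
im = np.random.randint(0, 256, (300, 300))
t_old = timeit.timeit(lambda: old(im), number=3)
t_new = timeit.timeit(lambda: new(im), number=3)
print(f"old: {t_old:.3f}s  new: {t_new:.3f}s  (3 runs each)")
```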
    
  • 2020-12-09 14:31

    Try the following code, which mimics MATLAB's im2col(...) function:

    import numpy as np
    
    def im2col(Im, block, style='sliding'):
        """Rearrange every bx x by block of Im into a column (like MATLAB's im2col).
    
        block = (bx, by). Only the 'sliding' style is implemented; the
        `style` argument is currently ignored.
        """
        bx, by = block
        Imx, Imy = Im.shape
        Imcol = []
        # Slide down the rows first, then across the columns (MATLAB order)
        for j in range(0, Imy - by + 1):
            for i in range(0, Imx - bx + 1):
                # Flatten the block in column-major order via the transpose
                Imcol.append(Im[i:i+bx, j:j+by].T.reshape(bx*by))
        return np.asarray(Imcol).T
    
    if __name__ == '__main__':
        Im = np.arange(6*6).reshape(6, 6)
        patchsize = 3
        print(Im)
        out = im2col(Im, (patchsize, patchsize))
        print(out)
        print(out.shape)
        print(len(out))
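The loops above can be replaced by a single strided view plus a transpose. A minimal sketch, assuming NumPy 1.20 or later for `sliding_window_view`; `im2col_sliding` is a hypothetical name, and the axis order is chosen so the output matches the loop version column for column:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def im2col_sliding(im, block):
    """Vectorized 'sliding' im2col: one column per block position.

    Columns are ordered with the block moving down the rows first, and each
    block is flattened in column-major (MATLAB) order.
    """
    bx, by = block
    # w[i, j, a, b] == im[i + a, j + b]
    w = sliding_window_view(im, (bx, by))
    # Reorder to [b, a, j, i] so a row-major reshape gives row index
    # b*bx + a (column-major within the block) and column index
    # j*n_i + i (the row position varies fastest).
    return w.transpose(3, 2, 1, 0).reshape(bx * by, -1)


im = np.arange(36).reshape(6, 6)
out = im2col_sliding(im, (3, 3))
print(out.shape)    # (9, 16)
print(out[:, 0])    # [ 0  6 12  1  7 13  2  8 14]
```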
    