Retrieve the positions of elements that meet some criteria in numpy

没有蜡笔的小新 2021-01-13 22:24

Given the 2D array data, how can I retrieve the positions (indices) of the 7 and the 11 shown in bold? They are the only elements whose neighbours all share the element's own value.

4 Answers
  • 2021-01-13 22:33

    Using scipy, you could characterize such points as those that are both the maximum and the minimum of their 3x3 neighborhood:

    import numpy as np
    import scipy.ndimage as filters  # the scipy.ndimage.filters namespace is deprecated
    
    def using_filters(data):
        return np.where(np.logical_and.reduce(
            [data == f(data, footprint=np.ones((3,3)), mode='constant', cval=np.inf)
             for f in (filters.maximum_filter, filters.minimum_filter)]))  
    
    using_filters(data)
    # (array([2, 3]), array([5, 9]))
    

    Using only numpy, you could compare data with 8 shifted slices of itself to find the points that are equal to all of their neighbors:

    def using_eight_shifts(data):
        h, w = data.shape
        data2 = np.empty((h+2, w+2))
        data2[(0,-1),:] = np.nan
        data2[:,(0,-1)] = np.nan
        data2[1:1+h,1:1+w] = data
    
        result = np.where(np.logical_and.reduce([
            (data2[i:i+h,j:j+w] == data)
            for i in range(3)
            for j in range(3)
            if not (i==1 and j==1)]))
        return result
    

    As you can see above, this strategy makes an expanded array which has a border of NaNs around the data. This allows the shifted slices to be expressed as data2[i:i+h,j:j+w].

    If you know that you are going to be comparing against neighbors, it might behoove you to define data with a border of NaNs from the very beginning so you don't have to make a second array as done above.
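
    As a rough sketch of that idea, np.pad can add the NaN border in one call. The helper names pad_with_nan and uniform_positions are made up for this illustration and are not part of the benchmark in this answer; the data is converted to float because an integer array cannot hold NaN.

```python
import numpy as np

def pad_with_nan(data):
    """Return a float copy of `data` surrounded by a one-cell NaN border."""
    return np.pad(np.asarray(data, dtype=float), 1,
                  mode='constant', constant_values=np.nan)

def uniform_positions(padded):
    """Find cells of the unpadded data that equal all 8 of their neighbours.

    `padded` is assumed to already carry the NaN border, so no second
    array has to be allocated here.
    """
    h, w = padded.shape[0] - 2, padded.shape[1] - 2
    centre = padded[1:1 + h, 1:1 + w]
    mask = np.ones((h, w), dtype=bool)
    for i in range(3):
        for j in range(3):
            if (i, j) != (1, 1):
                # NaN never compares equal, so cells on the edge drop out.
                mask &= padded[i:i + h, j:j + w] == centre
    return np.where(mask)
```

    The returned indices refer to the original, unpadded array, so they can be compared directly with the output of using_eight_shifts.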

    Using eight shifts (and comparisons) is much faster than looping over each cell in data and comparing it against its neighbors:

    def using_quadratic_loop(data):
        return np.array([[i,j]
                for i in range(1,np.shape(data)[0]-1)
                for j in range(1,np.shape(data)[1]-1)
                if np.all(data[i-1:i+2,j-1:j+2]==data[i,j])]).T
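
    Another numpy-only variant, offered only as a sketch and not included in the timings of this answer, is to look at every 3x3 window at once with sliding_window_view. This assumes NumPy 1.20 or newer, and the function name using_sliding_windows is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def using_sliding_windows(data):
    """Return (rows, cols) of interior cells equal to all 8 neighbours."""
    data = np.asarray(data)
    # One 3x3 window per interior cell: shape (h-2, w-2, 3, 3).
    windows = sliding_window_view(data, (3, 3))
    # Centre value of each window, shaped to broadcast against it.
    centre = data[1:-1, 1:-1, None, None]
    mask = (windows == centre).all(axis=(2, 3))
    rows, cols = np.nonzero(mask)
    # Window (0, 0) is centred on cell (1, 1) of the original array.
    return rows + 1, cols + 1

small = np.array([[1, 1, 1, 2],
                  [1, 1, 1, 2],
                  [1, 1, 1, 2]])
# The only uniform 3x3 window is the one centred at row 1, column 1.
rows, cols = using_sliding_windows(small)
```

    Like the loop version, this never reports cells on the outer edge, because no full 3x3 window is centred there.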
    

    Here is a benchmark (total time in seconds for 10 calls on the 250x550 array built in the code below):

    using_filters            : 0.130
    using_eight_shifts       : 0.340
    using_quadratic_loop     : 18.794
    

    Here is the code used to produce the benchmark:

    import timeit
    import operator
    import numpy as np
    import scipy.ndimage as filters  # the scipy.ndimage.filters namespace is deprecated
    
    data  = np.array([
        [0,1,2,3,4,7,6,7,8,9,10], 
        [3,3,3,4,7,7,7,8,11,12,11],  
        [3,3,3,5,7,7,7,9,11,11,11],
        [3,4,3,6,7,7,7,10,11,11,11],
        [4,5,6,7,7,9,10,11,11,11,11]
        ])
    
    data = np.tile(data, (50,50))
    
    def using_filters(data):
        return np.where(np.logical_and.reduce(
            [data == f(data, footprint=np.ones((3,3)), mode='constant', cval=np.inf)
             for f in (filters.maximum_filter, filters.minimum_filter)]))    
    
    
    def using_eight_shifts(data):
        h, w = data.shape
        data2 = np.empty((h+2, w+2))
        data2[(0,-1),:] = np.nan
        data2[:,(0,-1)] = np.nan
        data2[1:1+h,1:1+w] = data
    
        result = np.where(np.logical_and.reduce([
            (data2[i:i+h,j:j+w] == data)
            for i in range(3)
            for j in range(3)
            if not (i==1 and j==1)]))
        return result
    
    
    def using_quadratic_loop(data):
        return np.array([[i,j]
                for i in range(1,np.shape(data)[0]-1)
                for j in range(1,np.shape(data)[1]-1)
                if np.all(data[i-1:i+2,j-1:j+2]==data[i,j])]).T
    
    np.testing.assert_equal(using_quadratic_loop(data), using_filters(data))
    np.testing.assert_equal(using_eight_shifts(data), using_filters(data))
    
    timing = dict()
    for f in ('using_filters', 'using_eight_shifts', 'using_quadratic_loop'):
        timing[f] = timeit.timeit('{f}(data)'.format(f=f),
                                  'from __main__ import data, {f}'.format(f=f),
                                  number=10) 
    
    for f, t in sorted(timing.items(), key=operator.itemgetter(1)):
        print('{f:25}: {t:.3f}'.format(f=f, t=t))
    
  • 2021-01-13 22:36

    I have tested it and the following code works:

    for i in range(1,np.shape(data)[0]-1):
        for j in range(1,np.shape(data)[1]-1):
            if np.all(data[i-1:i+2,j-1:j+2]==data[i,j]):
                print(np.array([i,j], dtype=np.int64))
    
  • 2021-01-13 22:40

    I used a list comprehension, but there may be a better way:

    A = [(i,j) for i in range(1,data.shape[0]-1) for j in range(1,data.shape[1]-1) if all((data[i-1:i+2,j-1:j+2]==data[i,j]).flatten())]
    

    EDIT:

    If you want the form array([i,j],dtype=int64) then you just need to modify the first part:

    A= [np.array([i,j], dtype=np.int64) for i in range(1,data.shape[0]-1) for j in range(1,data.shape[1]-1) if all((data[i-1:i+2,j-1:j+2]==data[i,j]).flatten())]
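
    If the positions come from an np.where call instead, they arrive as a tuple of row and column arrays. A small sketch of converting that tuple into the same [i, j] form follows; positions_as_pairs is a hypothetical helper name, and the sample input is the result reported for the example data in the first answer.

```python
import numpy as np

def positions_as_pairs(index_arrays):
    """Turn an np.where-style (rows, cols) tuple into an (N, 2) int64 array."""
    return np.column_stack(index_arrays).astype(np.int64)

# Tuple in the shape returned by np.where for a 2D array.
pairs = positions_as_pairs((np.array([2, 3]), np.array([5, 9])))
for pair in pairs:
    print(repr(pair))  # each row is one [i, j] position
```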
    
  • 2021-01-13 22:40
    displacements = [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 1], [1, -1], [1, 0], [1, 1]]
    
    for x in range(1, data.shape[0] - 1):
        for y in range(1, data.shape[1] - 1):
            if all((data[x][y] == data[x + a][y + b]) for a, b in displacements):
                print(np.array([x, y], dtype=np.int64))
    

    Not as succinct as the other answers, but it's clear and prints the correct output. I think it's also a little easier to change/add displacement values.

    Whoops, didn't realize you wanted all 8 neighbors. Easy fix though. :)
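
    The point about displacement lists being easy to change could be sketched as a function that takes the list as a parameter. The names matching_positions, EIGHT_NEIGHBOURS and FOUR_NEIGHBOURS are invented for this illustration, and the four-neighbour list is just one possible alternative criterion.

```python
import numpy as np

EIGHT_NEIGHBOURS = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
                    (0, 1), (1, -1), (1, 0), (1, 1)]
FOUR_NEIGHBOURS = [(-1, 0), (0, -1), (0, 1), (1, 0)]

def matching_positions(data, displacements=EIGHT_NEIGHBOURS):
    """Return (x, y) tuples whose value equals every displaced neighbour."""
    data = np.asarray(data)
    # Skip a margin wide enough that x + a and y + b stay inside the array.
    reach = max(max(abs(a), abs(b)) for a, b in displacements)
    found = []
    for x in range(reach, data.shape[0] - reach):
        for y in range(reach, data.shape[1] - reach):
            if all(data[x, y] == data[x + a, y + b] for a, b in displacements):
                found.append((x, y))
    return found
```

    Passing FOUR_NEIGHBOURS relaxes the test to the cells directly above, below, left and right, so every position found with the eight-neighbour list is also found with it.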
