Efficiently find indices of nearest points on non-rectangular 2D grid

Submitted by 倖福魔咒の on 2020-01-14 10:16:05

Question


I have an irregular (non-rectangular) lon/lat grid and a bunch of points in lon/lat coordinates, which should correspond to points on the grid (though they might be slightly off for numerical reasons). Now I need the grid indices of the corresponding lon/lat points.

I've written a function which does this, but it is REALLY slow.

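import numpy as np

def find_indices(lon, lat, x, y):
    # Stack the grids into a (rows, cols, 2) array of (lon, lat) pairs
    lonlat = np.dstack([lon, lat])
    delta = np.abs(lonlat - [x, y])
    # Euclidean distance to every grid point, then the flat index of the minimum
    ij_1d = np.linalg.norm(delta, axis=2).argmin()
    i, j = np.unravel_index(ij_1d, lon.shape)
    return i, j

ind = [find_indices(lon, lat, *p) for p in points]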

I'm pretty sure there's a better (and faster) solution in numpy/scipy. I've already googled quite a lot, but the answer has so far eluded me.

Any suggestions on how to efficiently find the indices of the corresponding (nearest) points?

PS: This question emerged from another one (linked in the original post).

Edit: Solution

Based on @Cong Ma's answer, I've found the following solution:

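import numpy as np
import scipy.spatial

def find_indices(points, lon, lat, tree=None):
    # Work on the transposed grid; a prebuilt tree must come from this same transposed grid
    lon, lat = lon.T, lat.T
    if tree is None:
        # (N, 2) array of all grid points, then the KD-tree over them
        lonlat = np.column_stack((lon.ravel(), lat.ravel()))
        tree = scipy.spatial.cKDTree(lonlat)
    # Nearest grid point (k=1) for each query point
    dist, idx = tree.query(points, k=1)
    # Flat indices -> 2D indices on the (transposed) grid, also when a tree is passed in
    ind = np.column_stack(np.unravel_index(idx, lon.shape))
    return [(i, j) for i, j in ind]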
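Because the question notes that the points "might be slightly off for numerical reasons", it can help to check that each point really has a grid point nearby. A minimal sketch, assuming a small hypothetical 3x4 grid and a hypothetical tolerance `tol`: `cKDTree.query` accepts `distance_upper_bound`, and query points with no neighbour inside that bound come back with an infinite distance and the index `tree.n`.

```python
import numpy as np
from scipy.spatial import cKDTree

# Hypothetical 3x4 grid: lon[i, j] = j, lat[i, j] = i
lon, lat = np.meshgrid(np.arange(4.0), np.arange(3.0))
tree = cKDTree(np.column_stack((lon.ravel(), lat.ravel())))

points = np.array([[2.0 + 1e-9, 1.0],   # on grid point (1, 2) up to round-off
                   [1.5, 0.5]])         # halfway between grid points
tol = 1e-6                               # hypothetical tolerance for "on the grid"

dist, idx = tree.query(points, k=1, distance_upper_bound=tol)
# Misses are reported with idx == tree.n and dist == inf
found = idx < tree.n
```

Here only the first point is matched (flat index 6, i.e. grid cell (1, 2)); the second is flagged as not lying on the grid.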

To put this solution, and the one from Divakar's answer, into perspective, here are some timings of the function in which I use find_indices and where it is the speed bottleneck (see the linked question above):

spatial_contour_frequency/pil0                :   331.9553
spatial_contour_frequency/pil1                :   104.5771
spatial_contour_frequency/pil2                :     2.3629
spatial_contour_frequency/pil3                :     0.3287

pil0 is my initial approach, pil1 Divakar's, and pil2/pil3 the final solution above; in pil2 the tree is created on the fly (i.e. in every iteration of the loop in which find_indices is called), in pil3 only once (see the other thread for details). Divakar's refinement of my initial approach gives me a roughly 3x speed-up, and cKDTree takes this to a whole new level with another roughly 44x speed-up! Moving the creation of the tree out of the function makes things about 7x faster still.
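The pil3 pattern (build the tree once, query it many times) can be sketched as follows. This is an illustration on random data, not the code from the other thread; `build_grid_tree` and `nearest_grid_indices` are hypothetical helper names, and unlike the solution above this sketch does not transpose the grid.

```python
import numpy as np
from scipy.spatial import cKDTree

def build_grid_tree(lon, lat):
    """Hypothetical helper: build the KD-tree once from a 2D lon/lat grid."""
    return cKDTree(np.column_stack((lon.ravel(), lat.ravel())))

def nearest_grid_indices(tree, points, shape):
    """Hypothetical helper: 2D indices of the nearest grid point for each query point."""
    _, idx = tree.query(points, k=1)
    return np.column_stack(np.unravel_index(idx, shape))

rng = np.random.default_rng(0)
lon = rng.random((40, 60))
lat = rng.random((40, 60))
tree = build_grid_tree(lon, lat)        # built once, outside the loop

batches, results = [], []
for _ in range(5):                      # e.g. one batch of points per loop iteration
    points = rng.random((10, 2))
    batches.append(points)
    results.append(nearest_grid_indices(tree, points, lon.shape))
```

The tree construction cost is paid once, and each loop iteration only pays for the queries.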


Answer 1:


If the points are sufficiently localized, you can use scipy.spatial's cKDTree implementation directly, as I discussed in another post. That post was about interpolation, but you can ignore that part and just use the query part.

tl;dr version:

Read up on the documentation of scipy.spatial.cKDTree. Create the tree by passing an (n, m)-shaped NumPy ndarray to the initializer; the tree is then built from the n m-dimensional coordinates.

tree = scipy.spatial.cKDTree(array_of_coordinates)

After that, use tree.query() to retrieve the k nearest neighbors (optionally approximate and parallelized, see the docs), or use tree.query_ball_point() to find all neighbors within a given distance tolerance.
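A minimal sketch of both queries on a small hypothetical grid, sheared so that it is not rectangular in lon/lat. The grid values and the radius `r=0.1` are assumptions for illustration; in recent SciPy versions `query` also takes `eps` (approximate search) and `workers` (parallel queries), which are not used here.

```python
import numpy as np
from scipy.spatial import cKDTree

# Hypothetical 3x4 grid, sheared: lon[i, j] = j + 0.5 * i, lat[i, j] = i
lon, lat = np.meshgrid(np.arange(4.0), np.arange(3.0))
lon = lon + 0.5 * lat

# (n, m) = (12, 2) array of coordinates -> tree
tree = cKDTree(np.column_stack((lon.ravel(), lat.ravel())))

# Query points slightly off the grid points (1, 2) and (2, 0)
points = np.array([[2.52, 0.98],
                   [1.01, 2.03]])

# Nearest neighbour (k=1): distances and flat indices into lon.ravel()
dist, idx = tree.query(points, k=1)
ij = np.column_stack(np.unravel_index(idx, lon.shape))

# All grid points within a radius of 0.1 of each query point (one list per point)
neighbours = tree.query_ball_point(points, r=0.1)
```

`idx` holds the flat indices `[6, 8]`, which `np.unravel_index` maps back to the grid cells (1, 2) and (2, 0).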

If the points are not well localized, so that spherical curvature or non-trivial topology comes into play, you can try breaking the manifold into multiple parts, each small enough to be treated as local.
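A different workaround, not the splitting suggested here but a common alternative, is to map lon/lat (assumed to be in degrees, on a spherical Earth) to 3D points on the unit sphere and build the tree there. The straight-line (chord) distance between two points on the sphere is 2·sin(θ/2) for great-circle angle θ, which grows monotonically with θ, so the nearest neighbour is the same as by great-circle distance, and the ±180° longitude wrap is handled automatically. The grid below is hypothetical and straddles the antimeridian.

```python
import numpy as np
from scipy.spatial import cKDTree

def lonlat_to_xyz(lon_deg, lat_deg):
    """Map lon/lat in degrees to 3D points on the unit sphere."""
    lon = np.radians(lon_deg)
    lat = np.radians(lat_deg)
    return np.column_stack((np.cos(lat) * np.cos(lon),
                            np.cos(lat) * np.sin(lon),
                            np.sin(lat)))

# Hypothetical 2x3 grid crossing the antimeridian (179 -> -176)
lon = np.array([[178.0, 179.0, -176.0],
                [178.0, 179.0, -176.0]])
lat = np.array([[10.0, 10.0, 10.0],
                [11.0, 11.0, 11.0]])
points = np.array([[-179.8, 10.95]])    # 1.2 degrees east of lon = 179

# Tree on the unit sphere
tree = cKDTree(lonlat_to_xyz(lon.ravel(), lat.ravel()))
_, idx = tree.query(lonlat_to_xyz(points[:, 0], points[:, 1]), k=1)
ij = np.column_stack(np.unravel_index(idx, lon.shape))

# For comparison: a tree on raw lon/lat ignores the wrap-around
raw_tree = cKDTree(np.column_stack((lon.ravel(), lat.ravel())))
_, raw_idx = raw_tree.query(points, k=1)
raw_ij = np.column_stack(np.unravel_index(raw_idx, lon.shape))
```

The sphere-based tree picks grid cell (1, 1) at lon = 179, while the raw lon/lat tree picks (1, 2) at lon = -176, which is about three times farther away on the sphere.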




Answer 2:


Here's a generic vectorized approach using scipy.spatial.distance.cdist -

import numpy as np
import scipy.spatial.distance

# Stack lon and lat arrays as columns to form an Nx2 array, where N is the number of grid points (lon.size)
lonlat = np.column_stack((lon.ravel(),lat.ravel()))

# Distance matrix of shape (N, number of points); argmin along axis 0 gives the nearest grid point per point
idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)

# Get the indices corresponding to grid's shape as the final output
ind = np.column_stack(np.unravel_index(idx,lon.shape)).tolist()
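To make the axis convention concrete, here is a tiny hypothetical 2x3 grid and two query points (the values are made up for illustration). `cdist(lonlat, points)` has one row per grid point and one column per query point, so `argmin(0)` returns one flat grid index per query point.

```python
import numpy as np
import scipy.spatial.distance

# Hypothetical 2x3 grid
lon = np.array([[0.0, 1.0, 2.0],
                [0.5, 1.5, 2.5]])
lat = np.array([[0.0, 0.1, 0.2],
                [1.0, 1.1, 1.2]])
points = np.array([[1.4, 1.0],    # closest to grid cell (1, 1)
                   [2.1, 0.3]])   # closest to grid cell (0, 2)

lonlat = np.column_stack((lon.ravel(), lat.ravel()))   # shape (6, 2)
d = scipy.spatial.distance.cdist(lonlat, points)       # shape (6, 2): rows = grid points, cols = query points
idx = d.argmin(axis=0)                                 # flat grid index for each query point
ind = np.column_stack(np.unravel_index(idx, lon.shape)).tolist()
```

This gives `idx == [4, 2]` and `ind == [[1, 1], [0, 2]]`. Note that the full N x M distance matrix is held in memory, which matters for large grids or many points.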

Sample run -

In [161]: lon
Out[161]: 
array([[-11.   ,  -7.82 ,  -4.52 ,  -1.18 ,   2.19 ],
       [-12.   ,  -8.65 ,  -5.21 ,  -1.71 ,   1.81 ],
       [-13.   ,  -9.53 ,  -5.94 ,  -2.29 ,   1.41 ],
       [-14.1  ,  -0.04 ,  -6.74 ,  -2.91 ,   0.976]])

In [162]: lat
Out[162]: 
array([[-11.2  ,  -7.82 ,  -4.51 ,  -1.18 ,   2.19 ],
       [-12.   ,  -8.63 ,  -5.27 ,  -1.71 ,   1.81 ],
       [-13.2  ,  -9.52 ,  -5.96 ,  -2.29 ,   1.41 ],
       [-14.3  ,  -0.06 ,  -6.75 ,  -2.91 ,   0.973]])

In [163]: lonlat = np.column_stack((lon.ravel(),lat.ravel()))

In [164]: idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)

In [165]: np.column_stack((np.unravel_index(idx,lon.shape))).tolist()
Out[165]: [[0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [3, 3]]

Runtime tests -

Define functions:

import numpy as np
import scipy.spatial.distance

def find_indices(lon,lat,x,y):
    lonlat = np.dstack([lon,lat])
    delta = np.abs(lonlat-[x,y])
    ij_1d = np.linalg.norm(delta,axis=2).argmin()
    i,j = np.unravel_index(ij_1d,lon.shape)
    return i,j

def loopy_app(lon,lat,pts):
    return [find_indices(lon,lat,pts[i,0],pts[i,1]) for i in range(pts.shape[0])]

def vectorized_app(lon,lat,points):
    lonlat = np.column_stack((lon.ravel(),lat.ravel()))
    idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)
    return np.column_stack((np.unravel_index(idx,lon.shape))).tolist()

Timings:

In [179]: lon = np.random.rand(100,100)

In [180]: lat = np.random.rand(100,100)

In [181]: points = np.random.rand(50,2)

In [182]: %timeit loopy_app(lon,lat,points)
10 loops, best of 3: 47 ms per loop

In [183]: %timeit vectorized_app(lon,lat,points)
10 loops, best of 3: 16.6 ms per loop

For squeezing out a bit more performance, np.concatenate on column-shaped arrays (e.g. lon.reshape(-1,1) and lat.reshape(-1,1) with axis=1) could be used in place of np.column_stack.
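A small check, on hypothetical arrays, that the np.concatenate replacement produces the same (N, 2) array as np.column_stack. Concatenating the raveled 1D arrays directly would instead give a flat array of length 2N. Whether the swap is measurably faster will depend on array sizes and NumPy version; this sketch only shows equivalence, not a speed-up.

```python
import numpy as np

lon = np.arange(6.0).reshape(2, 3)
lat = np.arange(6.0, 12.0).reshape(2, 3)

a = np.column_stack((lon.ravel(), lat.ravel()))
# np.concatenate needs explicit (N, 1) columns and axis=1 to build the same (N, 2) array
b = np.concatenate((lon.reshape(-1, 1), lat.reshape(-1, 1)), axis=1)
# Pitfall: concatenating the 1D arrays just appends them, giving shape (12,)
wrong = np.concatenate((lon.ravel(), lat.ravel()))
```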



Source: https://stackoverflow.com/questions/32909087/efficiently-find-indices-of-nearest-points-on-non-rectangular-2d-grid
