Nearest neighbour search with a KD-tree

礼貌的吻别 2021-01-20 17:51

For a list of N points [(x_1, y_1), (x_2, y_2), ...] I am trying to find the nearest neighbour to each point based on distance. My dataset is too large to use a brute-force approach, so a KD-tree seems best.

4 Answers
  • 2021-01-20 18:20

    I implemented a solution to this problem and I think it might be helpful.

    from collections import namedtuple
    from operator import itemgetter
    from pprint import pformat
    from math import inf
    
    
    def nested_getter(idx1, idx2):
        def g(obj):
            return obj[idx1][idx2]
        return g
    
    
    class Node(namedtuple('Node', 'location left_child right_child')):
        def __repr__(self):
            return pformat(tuple(self))
    
    
    def kdtree(point_list, depth: int = 0):
        if not point_list:
            return None
    
        # points are (index, coordinates) pairs, so the dimension comes from
        # the coordinates tuple (assumed the same for all points)
        k = len(point_list[0][1])
        # Select axis based on depth so that axis cycles through all valid values
        axis = depth % k
    
        # Sort point list by axis and choose median as pivot element
        point_list.sort(key=nested_getter(1, axis))
        median = len(point_list) // 2
    
        # Create node and construct subtrees
        return Node(
            location=point_list[median],
            left_child=kdtree(point_list[:median], depth + 1),
            right_child=kdtree(point_list[median + 1:], depth + 1)
        )
    
    
    def nns(q, n, p, w, depth: int = 0):
        """
        NNS = Nearest Neighbour Search
        :param q: query point, given as an (index, coordinates) pair
        :param n: current node
        :param p: best (reference) point found so far
        :param w: best (reference) distance found so far
        :param depth: current depth, used to select the splitting axis
        :return: new_p, new_w
        """
    
        new_w = distance(q[1], n.location[1])
        # require new_w > 0 so that the query point never matches itself (p = q)
        if 0 < new_w < w:
            p, w = n.location, new_w

        # points are (index, coordinates) pairs, so the dimension is len(p[1])
        k = len(p[1])
        axis = depth % k
        n_value = n.location[1][axis]
        search_left_first = (q[1][axis] <= n_value)
        if search_left_first:
            if n.left_child and (q[1][axis] - w <= n_value):
                new_p, new_w = nns(q, n.left_child, p, w, depth + 1)
                if new_w < w:
                    p, w = new_p, new_w
            if n.right_child and (q[1][axis] + w >= n_value):
                new_p, new_w = nns(q, n.right_child, p, w, depth + 1)
                if new_w < w:
                    p, w = new_p, new_w
        else:
            if n.right_child and (q[1][axis] + w >= n_value):
                new_p, new_w = nns(q, n.right_child, p, w, depth + 1)
                if new_w < w:
                    p, w = new_p, new_w
            if n.left_child and (q[1][axis] - w <= n_value):
                new_p, new_w = nns(q, n.left_child, p, w, depth + 1)
                if new_w < w:
                    p, w = new_p, new_w
        return p, w
    
    
    def main():
        """Example usage of kdtree"""
        # kdtree() and nns() expect (index, coordinates) pairs, as built in city_houses()
        point_list = list(enumerate([(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]))
        tree = kdtree(point_list)
        print(tree)
    
    
    def city_houses():
        """
        Compute, for each of N cities, the distance to its nearest city.
        The first line of input contains N, the number of cities.
        Each of the next N lines contains two integers x and y, separated by a
        single space, locating a city at (x, y).
        It is guaranteed that a spot (x, y) contains at most one city.
        The output contains N lines, line i holding the distance from the i-th
        city of the input to its nearest city.
        """
        n = int(input())
        cities = []
        for i in range(n):
            city = i, tuple(map(int, input().split(' ')))
            cities.append(city)
        # print(cities)
        tree = kdtree(cities)
        # print(tree)
        ans = [(target[0], nns(target, tree, tree.location, inf)[1]) for target in cities]
        ans.sort(key=itemgetter(0))
        ans = [item[1] for item in ans]
        print('\n'.join(map(str, ans)))
    
    
    def distance(a, b):
        # Taxicab (L1) distance is used below; squared Euclidean distance also works if you prefer
        k = len(b)
        total = 0
        for i in range(k):
            total += abs(b[i] - a[i])
        return total
    
    
    if __name__ == '__main__':
        city_houses()
    
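    For a quick check without reading from stdin, here is a minimal usage sketch (assuming the kdtree() and nns() definitions and imports above; points are wrapped as (index, coordinates) pairs, just as city_houses() builds them):

    # usage sketch: find the nearest neighbour of each example point
    points = list(enumerate([(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]))
    tree = kdtree(points)
    for q in points:
        neighbour, dist = nns(q, tree, tree.location, inf)
        print(q[1], '->', neighbour[1], 'taxicab distance', dist)

    For these six points the nearest taxicab distances should come out as 2, 4, 6, 4, 2 and 4.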
  • 2021-01-20 18:26

    You can use sklearn.neighbors.KDTree's query_radius() method, which returns a list of the indices of the nearest neighbours within some radius (as opposed to returning the k nearest neighbours).

    from sklearn.neighbors import KDTree
    
    points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
    
    tree = KDTree(points, leaf_size=2)
    all_nn_indices = tree.query_radius(points, r=1.5)  # NNs within distance of 1.5 of point
    all_nns = [[points[idx] for idx in nn_indices] for nn_indices in all_nn_indices]
    for nns in all_nns:
        print(nns)
    

    Outputs:

    [(1, 1), (2, 2)]
    [(1, 1), (2, 2), (3, 3)]
    [(2, 2), (3, 3), (4, 4)]
    [(3, 3), (4, 4), (5, 5)]
    [(4, 4), (5, 5)]
    

    Note that each point includes itself in its list of nearest neighbours within the given radius. If you want to remove these identity points, the line computing all_nns can be changed to:

    all_nns = [
        [points[idx] for idx in nn_indices if idx != i]
        for i, nn_indices in enumerate(all_nn_indices)
    ]
    

    Resulting in:

    [(2, 2)]
    [(1, 1), (3, 3)]
    [(2, 2), (4, 4)]
    [(3, 3), (5, 5)]
    [(4, 4)]
    
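    If you also need the distances, query_radius() can return them as well. A small sketch reusing the tree and points from above:

    # return_distance=True makes query_radius return (indices, distances)
    all_nn_indices, all_nn_dists = tree.query_radius(points, r=1.5, return_distance=True)
    for indices, dists in zip(all_nn_indices, all_nn_dists):
        print(list(zip(indices, dists)))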
  • 2021-01-20 18:31

    This question is very broad and missing details. It's unclear what you tried, what your data looks like, and what counts as a nearest neighbour (is a point its own nearest neighbour?).

    Assuming you are not interested in the identity (with distance 0), you can query the two nearest neighbours and drop the first column. This is probably the easiest approach here.

    Code:

     import numpy as np
     from sklearn.neighbors import KDTree
     np.random.seed(0)
     X = np.random.random((5, 2))  # 5 points in 2 dimensions
     tree = KDTree(X)
     nearest_dist, nearest_ind = tree.query(X, k=2)  # k=2 nearest neighbours; the first is the point itself
     print(X)
     print(nearest_dist[:, 1])    # drop the identity column; results are sorted by distance
     print(nearest_ind[:, 1])     # drop the identity column
    

    Output

     [[ 0.5488135   0.71518937]
      [ 0.60276338  0.54488318]
      [ 0.4236548   0.64589411]
      [ 0.43758721  0.891773  ]
      [ 0.96366276  0.38344152]]
     [ 0.14306129  0.1786471   0.14306129  0.20869372  0.39536284]
     [2 0 0 0 1]
    
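    A possible follow-up, reusing nearest_dist and nearest_ind from above, pairs each point with its nearest neighbour:

     for i, (j, dist) in enumerate(zip(nearest_ind[:, 1], nearest_dist[:, 1])):
         print('point', i, '-> nearest neighbour', j, 'at distance', dist)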
  • 2021-01-20 18:34

    sklearn should be the best option. I wrote the code below some time back, when I needed a custom distance (I believe sklearn's KD tree does not support a custom distance metric). Adding it for reference.

    Adapted from my gist for 2D: https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8

    # From https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8
    # Author alex punnen
    from collections import namedtuple
    from operator import itemgetter
    import numpy as np
    
    def find_nearest_neighbour(node, point, distance_fn, current_axis):
        # Algorithm to find the nearest neighbour in a KD tree. The KD tree has done
        # a spatial sort of the given coordinates: at each level, points on one side
        # of the splitting value go into the left branch and points on the other side
        # go into the right branch, cycling through the axes as the tree descends.
        # To find the nearest neighbour, first descend into the branch on the same
        # side of the split as the query point. On the way back up, the other branch
        # can only contain a closer point if the splitting plane itself is closer
        # than the best distance found so far; otherwise it is discarded. This
        # pruning is valid for metrics (e.g. Euclidean, taxicab) where the per-axis
        # difference is a lower bound on the distance.
        # param node: the current node
        # param point: the point whose nearest neighbour is to be found
        # param distance_fn: function used to calculate the distance
        # param current_axis: axis split at this level; 0 or 1 in two dimensions

        if node is None:
            return None, None
        current_closest_node = node
        closest_known_distance = distance_fn(node.cell[0], node.cell[1], point[0], point[1])

        x = (node.cell[0], node.cell[1])
        y = point

        # search the branch on the query point's side of the split first
        if y[current_axis] < x[current_axis]:
            near_branch, far_branch = node.left_branch, node.right_branch
        else:
            near_branch, far_branch = node.right_branch, node.left_branch

        new_node, new_distance = find_nearest_neighbour(near_branch, point, distance_fn,
                                                        (current_axis + 1) % 2)
        if new_distance is not None and new_distance < closest_known_distance:
            closest_known_distance = new_distance
            current_closest_node = new_node

        # the far branch can only hold a closer point if the splitting plane
        # is closer to the query point than the best distance found so far
        if abs(y[current_axis] - x[current_axis]) < closest_known_distance:
            new_node, new_distance = find_nearest_neighbour(far_branch, point, distance_fn,
                                                            (current_axis + 1) % 2)
            if new_distance is not None and new_distance < closest_known_distance:
                closest_known_distance = new_distance
                current_closest_node = new_node

        return current_closest_node, closest_known_distance
    
    
    class Node(namedtuple('Node', 'cell, left_branch, right_branch')):
        # This class is taken from the Wikipedia code snippet for a KD tree
        pass
    
    def create_kdtree(cell_list, current_axis, no_of_axis):
        # Creates a KD tree recursively, following the snippet from Wikipedia
        # but generic over the number of axes and the data structure
        if not cell_list:
            return None
        # get the cells as a list of tuples; this is for 2 dimensions
        k = [(cell[0], cell[1]) for cell in cell_list]
        # for three dimensions it would be:
        # k = [(cell[0], cell[1], cell[2]) for cell in cell_list]
        k.sort(key=itemgetter(current_axis))  # sort on the current axis
        median = len(k) // 2  # take the median as the pivot
        axis = (current_axis + 1) % no_of_axis  # cycle the axis
        return Node(k[median],  # recurse
                    create_kdtree(k[:median], axis, no_of_axis),
                    create_kdtree(k[median + 1:], axis, no_of_axis))
    
    def euclidean_dist(x1, y1, x2, y2):
        a = np.array([x1, y1])
        b = np.array([x2, y2])
        return np.linalg.norm(a - b)
    
    
    np.random.seed(0)
    # cell_list = np.random.random((2, 2)).tolist()
    cell_list = [[2, 2], [4, 8], [10, 2]]
    print(cell_list)
    tree = create_kdtree(cell_list, 0, 2)

    node, distance = find_nearest_neighbour(tree, (1, 1), euclidean_dist, 0)
    print('Nearest Neighbour =', node.cell, distance)

    node, distance = find_nearest_neighbour(tree, (8, 1), euclidean_dist, 0)
    print('Nearest Neighbour =', node.cell, distance)
    
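    A quick way to sanity-check the tree search is to compare it against a brute-force scan (brute_force_nn is a hypothetical helper, not part of the original gist):

    # hypothetical sanity check: the brute-force minimum should match the tree result
    def brute_force_nn(cells, point):
        return min(((tuple(c), euclidean_dist(c[0], c[1], point[0], point[1]))
                    for c in cells), key=itemgetter(1))

    print(brute_force_nn(cell_list, (1, 1)))
    print(brute_force_nn(cell_list, (8, 1)))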