nearest neighbour search kdTree

前端 未结 4 802
礼貌的吻别
礼貌的吻别 2021-01-20 17:51

Given a list of N points [(x_1,y_1), (x_2,y_2), ... ], I am trying to find the nearest neighbour of each point based on distance. My dataset is too la

4条回答
  •  逝去的感伤
    2021-01-20 18:20

    I implemented a solution to this problem, and I think it might be helpful.

    from collections import namedtuple
    from operator import itemgetter
    from pprint import pformat
    from math import inf
    
    
    def nested_getter(idx1, idx2):
        """Return a callable that maps ``obj`` to ``obj[idx1][idx2]``.

        Intended as a ``key=`` function for sorting nested sequences, e.g.
        ``(index, coordinates)`` entries ordered by one coordinate.
        """
        return lambda obj: obj[idx1][idx2]
    
    
    class Node(namedtuple('Node', 'location left_child right_child')):
        """A k-d tree node: the stored entry plus its left and right subtrees."""

        def __repr__(self):
            # Show the node as a plain (location, left, right) tuple; nested
            # child nodes are rendered through this same method.
            as_plain_tuple = (self.location, self.left_child, self.right_child)
            return pformat(as_plain_tuple)
    
    
    def kdtree(point_list, depth: int = 0):
        """Build a k-d tree from a list of ``(index, coordinates)`` entries.

        Each entry is a pair whose second element is the coordinate sequence;
        entries are ordered by ``entry[1][axis]``. All coordinate sequences are
        assumed to have the same length.

        :param point_list: list of ``(index, coordinates)`` entries
        :param depth: current tree depth; selects the splitting axis
        :return: the root ``Node`` of the (sub)tree, or ``None`` for an empty list
        """
        if not point_list:
            return None

        # The dimensionality is the length of the coordinate sequence, not of
        # the (index, coordinates) pair itself (which has length 2 regardless
        # of how many coordinates there are).
        k = len(point_list[0][1])
        # Select axis based on depth so that axis cycles through all valid values
        axis = depth % k

        # Sort a copy by the chosen axis so the caller's list is not reordered,
        # then use the median entry as the pivot.
        ordered = sorted(point_list, key=nested_getter(1, axis))
        median = len(ordered) // 2

        # Entries before the median go left, entries after it go right.
        return Node(
            location=ordered[median],
            left_child=kdtree(ordered[:median], depth + 1),
            right_child=kdtree(ordered[median + 1:], depth + 1),
        )
    
    
    def nns(q, n, p, w, depth: int = 0):
        """
        NNS = Nearest Neighbor Search

        Recursively search the subtree rooted at ``n`` for the entry closest to
        ``q`` (as measured by ``distance``), ignoring entries at distance 0 so
        that ``q`` is never reported as its own neighbour.

        Subtree pruning compares ``w`` against the gap along a single axis, so
        it relies on ``distance`` never being smaller than that per-axis gap.

        :param q: query entry, an ``(index, coordinates)`` pair
        :param n: node (root of the subtree to search)
        :param p: ref point, the best entry found so far
        :param w: ref distance, the distance from ``q`` to ``p``
        :param depth: depth of ``n`` in the tree; selects the splitting axis
        :return: new_p, new_w -- the best entry and its distance after the search
        """
        new_w = distance(q[1], n.location[1])
        # below we test if new_w > 0 because we don't want to allow p = q
        if 0 < new_w < w:
            p, w = n.location, new_w

        # Use the number of coordinates, not the length of the
        # (index, coordinates) pair, so the axis cycles over every dimension.
        k = len(q[1])
        axis = depth % k
        q_value = q[1][axis]
        n_value = n.location[1][axis]

        # Visit the side of the splitting plane that contains q first; that is
        # the side more likely to tighten w before the other side is tested.
        if q_value <= n_value:
            order = ('left', 'right')
        else:
            order = ('right', 'left')

        for side in order:
            # The reachability test is evaluated with the *current* w, which
            # may already have been reduced by the first side visited.
            if side == 'left':
                child = n.left_child
                reachable = q_value - w <= n_value
            else:
                child = n.right_child
                reachable = q_value + w >= n_value
            if child is not None and reachable:
                new_p, new_w = nns(q, child, p, w, depth + 1)
                if new_w < w:
                    p, w = new_p, new_w
        return p, w
    
    
    def main():
        """Example usage of kdtree: build a tree from six 2-D points and print it."""
        points = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
        # kdtree sorts entries by entry[1][axis], so it needs
        # (index, coordinates) pairs rather than bare coordinate tuples.
        point_list = list(enumerate(points))
        tree = kdtree(point_list)
        print(tree)
    
    
    def city_houses():
        """
        Here we compute the distance to the nearest city from a list of N cities.
        The first line of input contains N, the number of cities.
        Each of the next N lines contain two integers x and y, which locate the city in (x,y),
        separated by whitespace.
        It's guaranteed that a spot (x,y) does not contain more than one city.
        The output contains N lines, the line i with a number representing the distance
        for the nearest city from the i-th city of the input.

        Distances are whatever ``distance`` computes. With N == 0 nothing is
        printed; with N == 1 there is no other city, so ``inf`` is printed.
        """
        n = int(input())
        cities = []
        for i in range(n):
            # split() without an argument tolerates repeated or trailing
            # whitespace, which split(' ') would turn into empty tokens.
            coords = tuple(map(int, input().split()))
            cities.append((i, coords))

        # An empty tree has no root to seed the search with.
        if not cities:
            return

        tree = kdtree(cities)
        ans = [(target[0], nns(target, tree, tree.location, inf)[1]) for target in cities]
        # Restore input order in case building the tree reordered `cities`.
        ans.sort(key=itemgetter(0))
        print('\n'.join(str(item[1]) for item in ans))
    
    
    def distance(a, b):
        """Return the taxicab (L1) distance between coordinate sequences a and b.

        The number of coordinates compared is ``len(b)``.
        """
        # Taxicab distance is used below. You can use squared euclidean distance if you prefer
        return sum(abs(b_i - a[i]) for i, b_i in enumerate(b))
    
    
    # Script entry point: run the stdin-driven nearest-city computation only
    # when this file is executed directly, not when it is imported.
    if __name__ == '__main__':
        city_houses()
    

提交回复
热议问题