Given a list of N points [(x_1, y_1), (x_2, y_2), ...],
I am trying to find the nearest neighbours to each point based on distance. My dataset is too large [...]
I implemented a solution to this problem, and I think it might be helpful.
from collections import namedtuple
from operator import itemgetter
from pprint import pformat
from math import inf
def nested_getter(idx1, idx2):
    """Return a callable that fetches ``obj[idx1][idx2]`` from its argument.

    Intended as a ``key=`` function for sorting nested sequences.
    """
    return lambda obj: obj[idx1][idx2]
class Node(namedtuple('Node', ['location', 'left_child', 'right_child'])):
    """A k-d tree node: the stored entry and its left/right subtrees."""

    def __repr__(self):
        """Pretty-print the node as a plain tuple (children are rendered recursively)."""
        location, left_child, right_child = self
        return pformat((location, left_child, right_child))
def kdtree(point_list, depth: int = 0):
    """Build a k-d tree from a list of ``(id, coords)`` entries.

    Each entry is a pair whose second item is the coordinate sequence; the
    first item is carried along untouched. The entry at the median of the
    current split axis becomes the node, and the two halves are built
    recursively.

    :param point_list: list of ``(id, coords)`` entries; it is not modified
    :param depth: depth of the node being built, used to pick the split axis
    :return: the root ``Node``, or ``None`` when ``point_list`` is empty
    """
    if not point_list:
        return None

    # Number of coordinate axes: the length of the coordinates (entry[1]),
    # not of the (id, coords) pair itself. Assumes all points share it.
    k = len(point_list[0][1])
    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k

    # Sort a copy by the split axis so the caller's list keeps its order,
    # then choose the median as pivot element
    ordered = sorted(point_list, key=nested_getter(1, axis))
    median = len(ordered) // 2

    # Create node and construct subtrees
    return Node(
        location=ordered[median],
        left_child=kdtree(ordered[:median], depth + 1),
        right_child=kdtree(ordered[median + 1:], depth + 1),
    )


def nns(q, n, p, w, depth: int = 0):
    """
    NNS = Nearest Neighbor Search in a tree built by ``kdtree``.

    Entries at distance 0 from ``q`` are never accepted, so ``q`` itself
    (when it is stored in the tree) is not returned as its own neighbour.

    :param q: query entry ``(id, coords)``
    :param n: node to search from (must not be ``None``)
    :param p: best entry found so far (reference point)
    :param w: distance from ``q`` to ``p``; pass ``inf`` when nothing is known
    :param depth: depth of ``n`` in the tree, used to recover its split axis
    :return: ``(new_p, new_w)``, the best entry and its distance after
        searching the subtree rooted at ``n``
    """
    q_coords = q[1]
    here = distance(q_coords, n.location[1])
    # Require here > 0 because we don't want to allow p = q
    if 0 < here < w:
        p, w = n.location, here

    # Must match the axis cycle used by kdtree: one axis per coordinate
    axis = depth % len(q_coords)
    split = n.location[1][axis]
    q_value = q_coords[axis]

    # Visit the side of the splitting plane that contains q first; it is the
    # more likely place to shrink w before the far side is considered.
    sides = ('left', 'right') if q_value <= split else ('right', 'left')
    for side in sides:
        # The pruning test is evaluated with the current w, so an improvement
        # found on the near side can rule out the far side entirely.
        if side == 'left':
            child = n.left_child
            may_contain_closer = q_value - w <= split
        else:
            child = n.right_child
            may_contain_closer = q_value + w >= split
        if child is not None and may_contain_closer:
            cand_p, cand_w = nns(q, child, p, w, depth + 1)
            if cand_w < w:
                p, w = cand_p, cand_w
    return p, w
def main():
    """Example usage of kdtree: build a tree from six 2-D points and print it."""
    coords = [(7, 2), (5, 4), (9, 6), (4, 7), (8, 1), (2, 3)]
    # kdtree reads the coordinates from entry[1], so pair every point with an
    # id; passing the bare (x, y) tuples raises TypeError inside the sort key.
    point_list = list(enumerate(coords))
    tree = kdtree(point_list)
    print(tree)
def city_houses():
    """
    Here we compute the distance to the nearest city from a list of N cities.
    The first line of input contains N, the number of cities.
    Each of the next N lines contain two integers x and y, which locate the city in (x,y),
    separated by whitespace.
    It's guaranteed that a spot (x,y) does not contain more than one city.
    The output contains N lines, the line i with a number representing the distance
    (as computed by ``distance``) for the nearest city from the i-th city of the input.
    """
    n = int(input())
    cities = []
    for i in range(n):
        # split() without an argument tolerates repeated or trailing
        # whitespace, where split(' ') would produce empty strings
        coords = tuple(map(int, input().split()))
        # Keep the input index with the coordinates so results can be
        # reported in input order
        cities.append((i, coords))

    tree = kdtree(cities)
    ans = [(target[0], nns(target, tree, tree.location, inf)[1]) for target in cities]
    # Restore input order regardless of how `cities` is ordered at this point
    ans.sort(key=itemgetter(0))
    ans = [item[1] for item in ans]
    print('\n'.join(map(str, ans)))
def distance(a, b):
    """Return the taxicab (L1) distance between coordinate sequences a and b.

    Iterates over the indices of ``b``; ``a`` must be at least as long.
    """
    # Taxicab distance is used below. You can use squared euclidean distance if you prefer
    return sum(abs(b_i - a[i]) for i, b_i in enumerate(b))
# Script entry point: read the cities from standard input and print, for each
# one, the distance to its nearest neighbour.
if __name__ == '__main__':
    city_houses()