Closest Pair Implementation in Python

前端 未结 4 1981
时光说笑
时光说笑 2021-02-04 17:32

I am trying to implement the closest pair problem in Python using divide and conquer. Everything seems to work fine, except that for some inputs the answer is wrong. My

相关标签:
4条回答
  • 2021-02-04 18:10

    You just need to change the seventh line of your closestSplitPair function definition from best=(Sy[i],Sy[i+j]) to best=dist(Sy[i],Sy[i+j]), and you will get the correct answer: ((94, 5), (99, -8), 13.92838827718412). You were missing the call to the dist function.

    This was pointed out by Padraic Cunningham's answer as the first problem.

    Best Regards.

    0 讨论(0)
  • 2021-02-04 18:14

    Brute force runs faster when it is built on standard-library functions, so it can be applied effectively to more than 3 points.

    from itertools import combinations
    
    def closest(points_list):
        """Return ``(distance, p1, p2)`` for the closest pair, by brute force.

        Every unordered pair of points is scored with ``dist``.  ``min`` then
        compares the resulting tuples, so the smallest distance wins.
        """
        scored = (
            (dist(first, second), first, second)
            for first, second in combinations(points_list, r=2)
        )
        return min(scored)
    

    The most effective way to divide the points is to split them into tiles. If you don't have outliers, you can simply split your space into equal parts and compare only points in the same or neighbouring tiles. The number of tiles should be as large as possible, but to avoid isolated tiles (where a point has no points in any neighbouring tile) you must limit the number of tiles by the number of points. Note that this limit alone does not rule isolated tiles out: four points at the corners of a square, for example, each land in a separate, non-adjacent tile, so no pair is ever compared. Full listing:

    from math import sqrt
    from itertools import combinations, product
    from collections import defaultdict
    import sys
    
    max_float = sys.float_info.max
    
    def dist(p1, p2):
        """Return the Euclidean distance between two 2-D points.

        Each point is an ``(x, y)`` pair.  The points are unpacked in the body
        rather than in the parameter list, because tuple parameters
        (``def dist((x1, y1), (x2, y2))``) were removed in Python 3 (PEP 3113)
        and are a SyntaxError there.  Callers pass two points positionally,
        exactly as before.
        """
        (x1, y1), (x2, y2) = p1, p2
        return sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
    
    def closest(points_list):
        """Brute-force closest pair of *points_list* as ``(distance, p1, p2)``.

        With fewer than two points there is no pair.  In that case
        ``(max_float, None, None)`` is returned, so the result can still be
        passed to ``min`` alongside real candidates.
        """
        if len(points_list) >= 2:
            return min((dist(a, b), a, b)
                       for a, b in combinations(points_list, r=2))
        return (max_float, None, None)
    
    def closest_between(pnt_lst1, pnt_lst2):
        """Closest pair taking one point from each list, as ``(distance, p1, p2)``.

        Either list may be empty or ``None`` (for example, a neighbour tile
        that does not exist).  In that case ``(max_float, None, None)`` is
        returned, so ``min`` effectively ignores it.
        """
        if pnt_lst1 and pnt_lst2:
            cross_pairs = product(pnt_lst1, pnt_lst2)
            return min((dist(a, b), a, b) for a, b in cross_pairs)
        return (max_float, None, None)
    
    def divide_on_tiles(points_list):
        """Bucket points into a grid of rectangular tiles over their bounding box.

        The grid has ``side = int(sqrt(len(points_list)))`` tiles per axis, so
        the number of tiles grows with the number of points.

        Returns a ``defaultdict(list)`` mapping ``(x_tile, y_tile)`` to the
        points in that tile.  Points lying exactly on the maximum x (or y)
        edge get index ``side``.  An empty input yields an empty mapping.
        """
        tiles = defaultdict(list)
        if not points_list:
            # min()/max() below would raise on an empty sequence.
            return tiles
        side = int(sqrt(len(points_list)))  # number of tiles on one side of square
        min_x = min(x for x, y in points_list)
        max_x = max(x for x, y in points_list)
        # The y extent must come from the y coordinates (it previously read x).
        min_y = min(y for x, y in points_list)
        max_y = max(y for x, y in points_list)
        # If every point shares an x (or a y), the extent is 0.  Fall back to a
        # unit tile size to avoid dividing by zero; all points then land in
        # tile column (or row) 0.
        tile_x_size = float(max_x - min_x) / side or 1.0
        tile_y_size = float(max_y - min_y) / side or 1.0
        for x, y in points_list:
            x_tile = int((x - min_x) / tile_x_size)
            y_tile = int((y - min_y) / tile_y_size)
            tiles[(x_tile, y_tile)].append((x, y))
        return tiles
    
    def closest_for_tile(tiles, coord):
        """Closest pair involving the points of tile *coord*, as ``(distance, p1, p2)``.

        The tile's points are compared among themselves and against four of
        its eight neighbours: right, top, top-right and top-left.  The other
        four neighbours are covered when those tiles run this function
        themselves, so across all tiles each pair of adjacent tiles is
        compared once.  Tiles further apart are never compared.

        *coord* is the ``(x_tile, y_tile)`` key.  It is unpacked in the body
        because tuple parameters are a SyntaxError in Python 3 (PEP 3113);
        callers still pass it positionally.
        """
        x_tile, y_tile = coord
        points = tiles[(x_tile, y_tile)]
        return min(closest(points),
                   # use dict.get so probing a missing neighbour does not
                   # create an empty tile in the defaultdict
                   closest_between(points, tiles.get((x_tile + 1, y_tile))),
                   closest_between(points, tiles.get((x_tile, y_tile + 1))),
                   closest_between(points, tiles.get((x_tile + 1, y_tile + 1))),
                   closest_between(points, tiles.get((x_tile - 1, y_tile + 1))))
    
    def find_closest_in_tiles(tiles):
        """Return the overall best ``(distance, p1, p2)`` across all tiles.

        Each tile key is scored with ``closest_for_tile`` and the smallest
        tuple wins.  Only pairs in the same or adjacent tiles are compared,
        so a point with no other point in its own or an adjacent tile is
        never paired.
        """
        per_tile_best = [closest_for_tile(tiles, key) for key in tiles]
        return min(per_tile_best)
    
    
    P1 = [(0, 0), (7, 6), (2, 20), (12, 5), (16, 16), (5, 8), (19, 7), (14, 22), (8, 19), (7, 29), (10, 11), (1, 13)]
    P2 = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]

    # print(...) with a single argument prints the same thing in Python 2 and 3;
    # the bare `print expr` statement is a SyntaxError in Python 3.
    # The expected output is noted on each line.
    print(find_closest_in_tiles(divide_on_tiles(P1)))  # (2.8284271247461903, (7, 6), (5, 8))
    print(find_closest_in_tiles(divide_on_tiles(P2)))  # (13.92838827718412, (94, 5), (99, -8))
    print(find_closest_in_tiles(divide_on_tiles(P1 + P2)))  # (2.8284271247461903, (7, 6), (5, 8))
    
    0 讨论(0)
  • 2021-02-04 18:15

    You have two problems. First, you forget to call dist when updating the best distance. But the main problem is that more than one recursive call happens, so a closer split pair you found can end up being overwritten by the default best, p3, q3 = d, None, None. I pass the best pair from closest_pair as an argument to closest_split_pair so that the value cannot be overwritten.

    def closest_split_pair(p_x, p_y, delta, best_pair):  # <- a parameter
        """Return a pair straddling the split that is closer than *delta*, or *best_pair*.

        p_x: the points sorted by x; its middle element defines the split line.
        p_y: the same points, assumed sorted by y (the strip scan relies on it).
        delta: the smallest distance found so far inside either half.
        best_pair: the pair achieving *delta*.  It is returned unchanged when
            no split pair is closer, so a recursive caller never loses it.
        """
        ln_x = len(p_x)
        mx_x = p_x[ln_x // 2][0]
        # Only points within delta of the split line can form a closer pair.
        s_y = [x for x in p_y if mx_x - delta <= x[0] <= mx_x + delta]
        best = delta
        for i in range(len(s_y) - 1):
            # In the y-sorted strip, a pair closer than delta is at most 7
            # positions apart, so j goes up to 7 but never past the end of
            # s_y.  The previous bound, min(i + 7, len(s_y) - i), checked only
            # 6 successors for i == 0 and grew with i, making the scan quadratic.
            for j in range(1, min(8, len(s_y) - i)):
                p, q = s_y[i], s_y[i + j]
                dst = dist(p, q)
                if dst < best:
                    best_pair = p, q
                    best = dst
        return best_pair
    

    The end of closest_pair looks like the following:

        # Recurse on the left (q) and right (r) halves; each call returns a pair of points.
        p_1, q_1 = closest_pair(srt_q_x, srt_q_y)
        p_2, q_2 = closest_pair(srt_r_x, srt_r_y)
        # Smallest distance found inside either half; this is the strip width for the split search.
        closest = min(dist(p_1, q_1), dist(p_2, q_2))
        # get min of both and then pass that as an arg to closest_split_pair
        # so that closest_split_pair returns it unchanged when no split pair is closer.
        mn = min((p_1, q_1), (p_2, q_2), key=lambda x: dist(x[0], x[1]))
        p_3, q_3 = closest_split_pair(p_x, p_y, closest,mn)
        # either return mn or we have a closer split pair
        return min(mn, (p_3, q_3), key=lambda x: dist(x[0], x[1]))
    

    You also have some other logic issues: your slicing logic is not correct. I made some changes to your code, where brute is just a simple brute-force double loop:

    def closestPair(Px, Py):
        """Divide-and-conquer closest pair; returns the pair ``(p, q)``.

        Px and Py hold the same points, sorted by x and by y respectively.
        Ranges of at most 3 points are delegated to ``brute``.
        """
        from collections import Counter

        if len(Px) <= 3:
            return brute(Px)

        # Floor division: '/' yields a float in Python 3, which cannot be used
        # as a slice index.
        mid = len(Px) // 2
        # Left/right halves by x.  Slicing the x-sorted list keeps them x-sorted.
        Qx, Rx = Px[:mid], Px[mid:]
        # Split Py into the same two halves while preserving y order.  The
        # previous filters were always true, which left Qy in x order.  A
        # multiset (Counter) makes duplicate points split exactly as in Qx/Rx.
        left_counts = Counter(Qx)
        Qy, Ry = [], []
        for p in Py:
            if left_counts[p]:
                left_counts[p] -= 1
                Qy.append(p)
            else:
                Ry.append(p)
        (p1, q1) = closestPair(Qx, Qy)
        (p2, q2) = closestPair(Rx, Ry)
        mn = min((p1, q1), (p2, q2), key=lambda pair: dist(pair[0], pair[1]))
        # Best distance within a half.  The old dist(p1, p2) mixed points from
        # the two different halves.
        d = dist(mn[0], mn[1])
        (p3, q3) = closest_split_pair(Px, Py, d, mn)
        return min(mn, (p3, q3), key=lambda pair: dist(pair[0], pair[1]))
    

    I only implemented the algorithm today, so there are no doubt some improvements to be made, but this will get you the correct answer.

    0 讨论(0)
  • 2021-02-04 18:29

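    Here is a recursive divide-and-conquer Python implementation of the closest pair problem, based on the heap data structure. It also handles negative integer coordinates. Popping more entries from the heap with heappop() returns further candidate pairs in increasing order of distance. These are only the pairs the algorithm examined, not necessarily the k closest pairs overall.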

    from __future__ import division
    from collections import namedtuple
    from random import randint
    import math as m
    import heapq as hq
    
    def get_key(item):
        """Sort key: the element at index 0 of *item* (the x field of a point)."""
        key = item[0]
        return key
    
    
    def closest_point_problem(points):
        """Print and return the best entry found for *points*.

        points: sequence of ``(x, y)`` pairs; negative coordinates are fine.

        Each point is wrapped in a ``pt(x, y)`` namedtuple and the list is
        sorted by x.  The list is handed to ``find_min``, which pushes
        ``(distance, ((x1, y1), (x2, y2)))`` candidates onto a heap.  The
        smallest entry is popped and printed as before, and it is now also
        returned so callers can use the result.

        Raises IndexError (from heappop) if find_min pushed no candidate,
        e.g. for fewer than two points.
        """
        pt = namedtuple('pt', 'x y')
        point = sorted((pt(p[0], p[1]) for p in points), key=get_key)
        heap = []
        visited_index = []
        find_min(0, len(point) - 1, point, heap, visited_index)
        best = hq.heappop(heap)
        print(best)
        return best
    
    def _push_pair(p, q, heap):
        """Push ``(distance, ((p.x, p.y), (q.x, q.y)))`` onto *heap* and return the distance."""
        distance = m.sqrt((p.x - q.x) ** 2 + (p.y - q.y) ** 2)
        hq.heappush(heap, (distance, ((p.x, p.y), (q.x, q.y))))
        return distance


    def find_min(start, end, point, heap, visited_index):
        """Push candidate closest pairs of ``point[start:end + 1]`` onto *heap*.

        point: objects with ``.x``/``.y`` attributes, assumed sorted by x (the
            visible caller sorts them with get_key).
        heap: list used as a heapq.  It receives
            ``(distance, ((x1, y1), (x2, y2)))`` entries.
        visited_index: no longer used; kept so existing callers still work.

        Returns the smallest distance in the range (``inf`` for fewer than two
        points).  After the call, the heap's smallest entry is the closest
        pair in the range.

        The previous version compared only x-adjacent points.  It therefore
        missed pairs that are close but not neighbours in x order (e.g.
        (0,0),(1,100),(2,0),(3,100)), and it pushed nothing for ranges of
        2-3 points.
        """
        count = end - start + 1
        if count < 2:
            return float('inf')
        if count <= 3:
            # Base case: brute-force every pair in the tiny range.
            best = float('inf')
            for i in range(start, end + 1):
                for j in range(i + 1, end + 1):
                    best = min(best, _push_pair(point[i], point[j], heap))
            return best
        mid = (start + end) // 2  # both halves keep at least two points
        mid_x = point[mid].x
        best = min(find_min(start, mid, point, heap, visited_index),
                   find_min(mid + 1, end, point, heap, visited_index))
        # Only points within `best` of the dividing line can form a closer
        # pair across it.  Scan them in y order and stop once the y gap alone
        # reaches `best`.
        strip = sorted((p for p in point[start:end + 1] if abs(p.x - mid_x) < best),
                       key=lambda p: p.y)
        for i in range(len(strip)):
            for j in range(i + 1, len(strip)):
                if strip[j].y - strip[i].y >= best:
                    break
                best = min(best, _push_pair(strip[i], strip[j], heap))
        return best
    
    # Demo: draw num_points random integer points from [-4n, 4n] x [-4n, 4n]
    # and report their closest pair.
    x = []
    num_points = 10
    limit = num_points << 2
    for i in range(num_points):
        # x coordinate first, then y, each drawn uniformly from [-limit, limit].
        x.append((randint(-limit, limit), randint(-limit, limit)))
    closest_point_problem(x)
    

    :)

    0 讨论(0)
提交回复
热议问题