Nearest Neighbor Search in Python without k-d tree

后端 未结 4 822
余生分开走
余生分开走 2021-02-08 08:16

I'm beginning to learn Python, coming from a C++ background. What I am looking for is a quick and easy way to find the closest point (the nearest neighbor) to some multidimensional query point.

4条回答
  •  情歌与酒
    2021-02-08 08:52

    For faster search and support for dynamic item insertion, you could use a binary search tree for 2D items in which the greater-than and less-than comparisons are defined by each item's distance to a reference point, (0,0).

    def dist(x1, x2):
        """Return the Euclidean distance between the first two coordinates of two points.

        Each coordinate is passed through ``float()`` first, so numeric strings
        are accepted. Coordinates beyond index 1 are ignored (2-D only).
        """
        # Function-scope import: nothing at module level in this snippet provides
        # it (the previous ``np.sqrt`` relied on a numpy import that never existed).
        import math

        return math.hypot(float(x1[0]) - float(x2[0]), float(x1[1]) - float(x2[1]))
    
    class Node(object):
        """A binary-search-tree node holding one point and two child links.

        Points are ordered by their distance to a caller-supplied center:
        strictly smaller distances go to the left subtree, equal or larger
        distances go to the right subtree.
        """

        def __init__(self, item=None,):
            self.item = item  # the stored point; None marks an empty node
            self.left = None  # subtree of points closer to the center
            self.right = None  # subtree of points at equal or greater distance

        def __repr__(self):
            return '{}'.format(self.item)

        def _add(self, value, center):
            """Insert point *value* into the subtree rooted at this node.

            An empty node (``item is None``) simply takes *value*. Otherwise
            the tree is walked iteratively, so unbalanced trees cannot hit the
            recursion limit, and a new leaf is attached. Points equidistant
            from *center* are kept and placed to the right. Only an exact
            duplicate of a stored point is rejected, with a printed message.

            Returns:
                self, unchanged, for compatibility with callers that reassign
                the result.
            """
            if self.item is None:
                # Store the value itself, never a wrapping Node.
                self.item = value
                return self

            key = dist(value, center)
            node = self
            while True:
                node_key = dist(node.item, center)
                # An exact duplicate must have the same key, so compare the
                # cheap floats first and the coordinates only on a tie.
                if key == node_key and tuple(value) == tuple(node.item):
                    print("BSTs do not support repeated items.")
                    break
                if key < node_key:
                    if node.left is None:
                        node.left = Node(value)
                        break
                    node = node.left
                else:
                    if node.right is None:
                        node.right = Node(value)
                        break
                    node = node.right
            return self

        def _isLeaf(self):
            """Return True when this node has no children."""
            return self.right is None and self.left is None
    
    class BSTC(object):
        """Binary search tree of 2-D points keyed by distance to a fixed center.

        A point with a smaller distance to ``center`` than a node's point goes
        left; an equal or larger distance goes right. Exact duplicate points
        are rejected.
        """

        def __init__(self, center=None):
            """Create an empty tree.

            Args:
                center: reference point used as the ordering key; defaults to
                    a fresh ``[0.0, 0.0]``. The fresh list is built per instance
                    rather than shared as a mutable default argument.
            """
            self.root = None
            self.count = 0
            self.center = [0.0, 0.0] if center is None else center

        def add(self, value):
            """Insert point *value*.

            An exact duplicate of a stored point is rejected with a printed
            message and does not change ``len(self)``.
            """
            if self.root is None:
                self.root = Node(value)
                self.count += 1
                return

            key = dist(value, self.center)
            node = self.root
            while True:
                node_key = dist(node.item, self.center)
                if key == node_key and tuple(value) == tuple(node.item):
                    print("BSTs do not support repeated items.")
                    return  # nothing stored, so the count stays as it is
                if key < node_key:
                    if node.left is None:
                        node.left = Node(value)
                        break
                    node = node.left
                else:
                    if node.right is None:
                        node.right = Node(value)
                        break
                    node = node.right
            self.count += 1

        def __len__(self): return self.count

        def closest(self, target):
            """Return ``(item, distance)`` for the stored point nearest *target*.

            Raises:
                ValueError: if the tree is empty.

            By the triangle inequality, |d(p, c) - d(t, c)| <= d(p, t) for the
            center c. A subtree whose keys all lie at least ``gap`` away from
            the target's key therefore cannot contain a closer point and is
            skipped. A single root-to-leaf descent is not enough, because the
            nearest point can sit in the branch the key comparison rejects.
            """
            if self.root is None:
                raise ValueError("closest() called on an empty BSTC")

            target_key = dist(target, self.center)
            best = None
            gap = float("inf")
            # Each entry: (subtree root, lower bound on the distance from
            # target to any point stored in that subtree).
            pending = [(self.root, 0.0)]
            while pending:
                node, bound = pending.pop()
                if node is None or bound >= gap:
                    continue
                d = dist(node.item, target)
                if d < gap:
                    gap = d
                    best = node.item
                    if gap == 0:
                        break  # exact hit; nothing can be closer
                node_key = dist(node.item, self.center)
                # Left keys are < node_key and right keys are >= node_key, which
                # gives these lower bounds.
                left_bound = max(bound, target_key - node_key)
                right_bound = max(bound, node_key - target_key)
                # Push the farther side first so the nearer side is popped
                # first and shrinks gap early.
                if target_key < node_key:
                    pending.append((node.right, right_bound))
                    pending.append((node.left, left_bound))
                else:
                    pending.append((node.left, left_bound))
                    pending.append((node.right, right_bound))
            return best, gap
    
    
    import util
    
    bst = util.BSTC()
    print(len(bst))

    arr = [
        (23.2323, 34.34535),
        (23.23, 36.34535),
        (53.23, 34.34535),
        (66.6666, 11.11111),
    ]
    for point in arr:
        bst.add(point)

    f = (11.111, 22.2222)
    print(bst.closest(f))
    # Brute-force distances to every point, for checking the tree's answer by eye.
    # A list comprehension prints the values on Python 3, where map() is lazy.
    print([util.dist(f, x) for x in arr])
    

提交回复
热议问题