How to get the K most distant points, given their coordinates?

无人及你 2021-02-20 12:41

We have a boring CSV with 10,000 rows of ages (float), titles (enum/int), scores (float), ...

  • We have N columns, each with int/float values, in a table.
  •  故里飘歌
    2021-02-20 13:06

    If you're interested in getting the most distant points, you can take advantage of all the methods that were developed for nearest neighbors; you just have to give them a different "metric".
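    Before reaching for a tree, it helps to have the exact baseline: for a single query point, compute every distance and keep the k largest. This costs O(D N) per query and is also handy for checking what a tree-based search returns. A minimal NumPy sketch; the `k_farthest` helper and its Minkowski `p` parameter are illustrative names, not part of any library:

```python
import numpy as np


def k_farthest(X, query, k, p=2):
    """Return (indices, distances) of the k rows of X farthest from `query`.

    Exact brute force, O(N * D) per query. `p` selects the Minkowski norm:
    1 = Manhattan, 2 = Euclidean, np.inf = Chebyshev. Assumes k >= 1.
    """
    X = np.asarray(X, dtype=float)
    query = np.asarray(query, dtype=float)
    k = min(k, len(X))
    # Distance from the query to every row of X.
    dists = np.linalg.norm(X - query, ord=p, axis=1)
    # argpartition moves the k largest distances into the last k slots (unordered).
    idx = np.argpartition(dists, len(dists) - k)[-k:]
    # Order those k candidates from farthest to nearest.
    idx = idx[np.argsort(dists[idx])[::-1]]
    return idx, dists[idx]


X = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0], [0.0, 3.0]])
idx, d = k_farthest(X, [0.0, 0.0], k=2)
print(idx, d)  # [2 3] [5. 3.]
```

    `argpartition` avoids a full sort, so only the k selected distances get sorted.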

    For example, using scikit-learn's nearest-neighbors and distance-metric tools, you can do something like this:

    import numpy as np
    from sklearn.neighbors import BallTree
    from sklearn.datasets import make_blobs
    from matplotlib import pyplot as plt
    
    
    def inverted_euclidean(x1, x2):
        # Pure-Python callback, so it is slow. You can speed it up with Cython
        # (as scikit-learn does for its built-in metrics) or with numba.
        # Squared Euclidean distance: it ranks points the same way as the
        # Euclidean distance but skips the square root.
        dist = np.sum((x1 - x2) ** 2)
        # Invert it so that far points get small values. Identical points get the
        # largest finite float instead of triggering a division by zero.
        if dist == 0:
            return np.finfo(np.float64).max
        return 1.0 / dist
    
    # Make up some fake data
    n_samples = 100000
    n_features = 200
    X, _ = make_blobs(n_samples=n_samples, centers=3, n_features=n_features, random_state=0)
    
    # We exploit the BallTree nearest-neighbor search to get the most distant points.
    # metric="pyfunc" plus func=... is how scikit-learn accepts a user-defined distance.
    # Caveat: the tree prunes its search with the triangle inequality, which this
    # inverted distance does not satisfy, so the result is not guaranteed to be exact.
    ball_tree = BallTree(X, leaf_size=50, metric="pyfunc", func=inverted_euclidean)
    
    # Some made up query, you can also provide a stack of points to query against
    test_point = np.zeros((1, n_features))
    # The returned "distances" are inverted values: smaller means farther away
    distance, distant_points_inds = ball_tree.query(X=test_point, k=10, return_distance=True)
    distant_points = X[distant_points_inds[0]]
    
    # We can try to visualize the query results (first two features only)
    plt.plot(X[:, 0], X[:, 1], ".b", alpha=0.1)
    plt.plot(test_point[:, 0], test_point[:, 1], "*r", markersize=9)
    plt.plot(distant_points[:, 0], distant_points[:, 1], "sg", markersize=5, alpha=0.8)
    plt.show()
    

    This plots the first two features of the data (blue dots), the query point (red star) and the 10 returned points (green squares).

    There are several things you could improve:

    1. I implemented the inverted_euclidean distance function with NumPy, but you can do what the scikit-learn developers do with their distance functions and implement it in Cython. You could also try to JIT-compile it with numba.
    2. The Euclidean distance may not be the right metric for finding the furthest points in your data; you're free to implement your own or simply use one that scikit-learn provides.
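    Whatever distance function you pick, scikit-learn's `DistanceMetric` documentation notes that a user-defined distance must be a true metric (non-negativity, identity, symmetry, triangle inequality) to be used within a BallTree, because the tree relies on those properties to prune its search. A small check, using a hypothetical helper that mirrors `inverted_euclidean` above, shows that the inverted distance fails two of them:

```python
import numpy as np


def inverted_sq_euclidean(x1, x2):
    """Hypothetical helper mirroring inverted_euclidean: 1 / squared distance."""
    dist = np.sum((np.asarray(x1) - np.asarray(x2)) ** 2)
    if dist == 0:
        return np.finfo(np.float64).max  # largest finite float64
    return 1.0 / dist


a, b, c = np.array([0.0]), np.array([10.0]), np.array([0.1])

# Identity: a true metric needs d(x, x) == 0.
print(inverted_sq_euclidean(a, a) == 0)  # False

# Triangle inequality: a true metric needs d(a, c) <= d(a, b) + d(b, c).
lhs = inverted_sq_euclidean(a, c)  # about 100
rhs = inverted_sq_euclidean(a, b) + inverted_sq_euclidean(b, c)  # about 0.02
print(lhs <= rhs)  # False
```

    So a tree built on such a function can discard branches that actually contain the answer; comparing against an exact brute-force search on a sample of your data is a cheap way to see how much that matters.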

    The nice thing about using the Ball Tree algorithm (or the KD-tree algorithm) is that, when the tree can prune well, each query needs roughly O(D log(N)) work to find the closest points (here, the furthest) in the training set, where D is the number of features. Building the tree costs about O(D N log(N)). So if you want the k furthest points for every point in the training set X, the total is roughly O(D N log(N)), which grows toward the brute-force O(D N^2) as k increases (and as the intrinsic dimensionality of the data grows, since pruning then becomes less effective).
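    For comparison, the O(D N^2) end of that range is what a plain brute-force search costs. A minimal NumPy sketch (the `all_k_farthest` name and its `chunk_size` parameter are hypothetical) that finds the k furthest points for every point in X, one block of rows at a time so that the full N x N distance matrix never has to fit in memory:

```python
import numpy as np


def all_k_farthest(X, k, chunk_size=1024):
    """For every row of X, return indices of its k farthest rows (Euclidean).

    Exact brute force: O(D * N^2) time, O(chunk_size * N) extra memory.
    Assumes k >= 1.
    """
    X = np.asarray(X, dtype=float)
    n = len(X)
    k = min(k, n)
    sq_norms = np.einsum("ij,ij->i", X, X)  # ||x_i||^2 for every row
    out = np.empty((n, k), dtype=np.intp)
    for start in range(0, n, chunk_size):
        block = X[start:start + chunk_size]
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; squared distances keep the ranking.
        d2 = sq_norms[start:start + chunk_size, None] + sq_norms[None, :] - 2.0 * block @ X.T
        # Column indices of the k largest squared distances in each row (unordered).
        part = np.argpartition(d2, n - k, axis=1)[:, -k:]
        rows = np.arange(len(block))[:, None]
        # Reorder each row's k candidates from farthest to nearest.
        order = np.argsort(-d2[rows, part], axis=1)
        out[start:start + chunk_size] = part[rows, order]
    return out


X = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
print(all_k_farthest(X, k=1, chunk_size=2).tolist())  # [[2], [2], [0]]
```

    The matrix product does the heavy lifting, so despite the quadratic cost this can be competitive with a tree that uses a pure-Python metric for moderate N.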
