We have a boring CSV with 10000 rows of ages (float), titles (enum/int), scores (float), ...
If you're interested in getting the most distant points, you can take advantage of all of the methods that were developed for nearest neighbors; you just have to give them a different "metric".
For example, using scikit-learn's nearest neighbors and distance metric tools, you can do something like this:
import numpy as np
from sklearn.neighbors import BallTree
from sklearn.metrics import DistanceMetric  # on older scikit-learn versions this lives in sklearn.neighbors
from sklearn.datasets import make_blobs
from matplotlib import pyplot as plt
def inverted_euclidean(x1, x2):
    # You can speed this up with cython, like scikit-learn does for its
    # built-in metrics, or by jit-compiling it with numba.
    dist = np.sum((x1 - x2) ** 2)  # squared euclidean; the square is monotonic, so the ranking is unchanged
    # Invert the distance so that nearby points look "far"; coincident points
    # get the biggest possible positive float that isn't inf.
    if dist == 0:
        return np.nextafter(np.inf, 0)
    return 1.0 / dist
# Make up some fake data
n_samples = 100000
n_features = 200
X, _ = make_blobs(n_samples=n_samples, centers=3, n_features=n_features, random_state=0)
# We exploit the BallTree algorithm to get the most distant points
ball_tree = BallTree(X, leaf_size=50, metric=DistanceMetric.get_metric("pyfunc", func=inverted_euclidean))
# A made-up query point; you can also pass a stack of points to query against
test_point = np.zeros((1, n_features))
distance, distant_points_inds = ball_tree.query(X=test_point, k=10, return_distance=True)
distant_points = X[distant_points_inds[0]]
# Visualize the query results (plotting only the first two of the n_features dimensions)
plt.plot(X[:, 0], X[:, 1], ".b", alpha=0.1)
plt.plot(test_point[:, 0], test_point[:, 1], "*r", markersize=9)
plt.plot(distant_points[:, 0], distant_points[:, 1], "sg", markersize=5, alpha=0.8)
plt.show()
This will plot the blob data (blue dots), the query point (red star), and the returned most distant points (green squares).
There are many points that you can improve on:
- I implemented the inverted_euclidean distance function with plain numpy, but you can try to do what the scikit-learn folks do with their built-in distance functions and implement it in cython. You could also try to jit-compile it with numba (a minimal numba sketch follows below).
- The nice thing about using the ball tree algorithm (or the KD-tree algorithm) is that for each queried point you only need about log(N) comparisons to find the furthest point in the training set. Building the ball tree itself takes roughly N log(N) comparisons, so if you want to find the k furthest points for every point in the training set (X), the whole procedure will have almost O(D N log(N)) complexity (where D is the number of features), which will grow towards O(D N^2) as k increases (see the batch-query sketch below).
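For the numba route, a minimal sketch of a jit-compiled distance function could look like the following (this assumes numba is available, and note that when used through the "pyfunc" metric every call still goes through a Python layer, so it won't match the speed of scikit-learn's cython metrics):
import numpy as np
from numba import njit
MAX_FLOAT = np.nextafter(np.inf, 0)  # largest finite float, computed outside the jitted code
@njit
def inverted_euclidean_numba(x1, x2):
    # Squared euclidean distance in an explicit loop so numba can compile it
    dist = 0.0
    for i in range(x1.shape[0]):
        diff = x1[i] - x2[i]
        dist += diff * diff
    if dist == 0.0:
        # Same trick as above: coincident points get the biggest finite float
        return MAX_FLOAT
    return 1.0 / dist
# It plugs into the tree the same way as the numpy version:
# ball_tree = BallTree(X, leaf_size=50,
#                      metric=DistanceMetric.get_metric("pyfunc", func=inverted_euclidean_numba))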
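As a rough illustration of that last point, you can query the tree with the whole training set to get the k most distant points for every row at once (a sketch reusing the ball_tree and X built above; k_furthest is just a name I made up):
k_furthest = 10
# Row i of furthest_inds holds the indices of the k_furthest points that are
# furthest from X[i]. Each row typically excludes the queried point itself:
# under the inverted metric its distance to itself is the largest finite float,
# so it is never among the k smallest (i.e. truly furthest) results.
inv_dists, furthest_inds = ball_tree.query(X=X, k=k_furthest, return_distance=True)
# The returned values are the inverted distances; convert back if you want the
# actual squared euclidean distances.
true_sq_dists = 1.0 / inv_dists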