How to use a user-defined metric for nearest neighbors in scikit-learn?

孤城傲影 2021-01-20 01:54

I am using scikit-learn 0.18.dev0. I know exactly the same question has been asked before here. I tried the answer presented there, but I am getting the following error:

1 answer
  •  梦毁少年i
    2021-01-20 02:20

    The proper keyword argument is metric, and it takes the function object itself:

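    import numpy as np
    from sklearn.neighbors import NearestNeighbors
    
    def mydist(x, y):
        # Custom distance between two 1-D samples: sum of squared differences
        return np.sum((x - y) ** 2)
    
    # Pass the callable (mydist, not a string) via the metric keyword.
    # Note: ball_tree pruning assumes a true metric (triangle inequality);
    # squared Euclidean distance is not one, so use algorithm='brute' or
    # take np.sqrt(...) in mydist if exact neighbours matter.
    nn = NearestNeighbors(n_neighbors=4, algorithm='ball_tree', metric=mydist)
    
    X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    nn.fit(X)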
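As a quick sanity check (a sketch, not part of the original answer), a callable that reproduces plain Euclidean distance should return the same neighbours as the built-in 'euclidean' string metric. The function name euclidean below is illustrative. kneighbors returns a (distances, indices) pair, and because X is passed in explicitly, each point comes back as its own nearest neighbour.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def euclidean(x, y):
    # Illustrative callable: a true metric (square root of summed squared differences)
    return np.sqrt(np.sum((x - y) ** 2))

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])

# One model uses the callable, the other the built-in string metric
custom = NearestNeighbors(n_neighbors=2, algorithm='ball_tree', metric=euclidean).fit(X)
builtin = NearestNeighbors(n_neighbors=2, algorithm='ball_tree', metric='euclidean').fit(X)

# Query the training points themselves; column 0 is each point (distance 0)
dist_c, idx_c = custom.kneighbors(X)
dist_b, idx_b = builtin.kneighbors(X)

print(idx_c)
```

Unlike mydist, this callable takes the square root, so it satisfies the triangle inequality and is safe to combine with ball_tree.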
    

    This is also mentioned in the docstring in the development version: https://github.com/scikit-learn/scikit-learn/blob/86b1ba72771718acbd1e07fbdc5caaf65ae65440/sklearn/neighbors/unsupervised.py#L48
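If the distance needs extra parameters, NearestNeighbors also accepts a metric_params dict whose entries are forwarded as keyword arguments to the callable. A minimal sketch, assuming a hypothetical weighted_euclidean function and made-up weights w:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def weighted_euclidean(x, y, w):
    # Hypothetical metric: Euclidean distance with per-feature weights w
    return np.sqrt(np.sum(w * (x - y) ** 2))

X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]])
w = np.array([10.0, 0.1])  # feature 0 weighs far more than feature 1

# metric_params entries are passed to the callable as keyword arguments
nn = NearestNeighbors(n_neighbors=2, algorithm='brute',
                      metric=weighted_euclidean, metric_params={'w': w})
nn.fit(X)

distances, indices = nn.kneighbors([[0.0, 0.0]])
print(indices, distances)
```

With these weights the point [0, 3] ends up nearer to the origin than [1, 0], the reverse of the unweighted ordering.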
