Sklearn kNN usage with a user defined metric

轮回少年 2020-12-01 05:35

Currently I'm working on a project that may require a kNN algorithm to find the top k nearest neighbors of a given point, say P. I'm using Python and the sklearn package to do

2 answers
  • 2020-12-01 06:13

    A small addition to the other answer: how to use a user-defined metric that takes additional arguments.

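    >>> import numpy as np
    >>> from sklearn.neighbors import KNeighborsClassifier
    >>> def mydist(x, y, **kwargs):
    ...     # the metric_params dict is read from the keyword arguments
    ...     return np.sum((x-y)**kwargs["metric_params"]["power"])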
    ...
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    >>> Y = np.array([-1, -1, -2, 1, 1, 2])
    >>> nbrs = KNeighborsClassifier(n_neighbors=4, algorithm='ball_tree',
    ...            metric=mydist, metric_params={"power": 2})
    >>> nbrs.fit(X, Y)
    KNeighborsClassifier(algorithm='ball_tree', leaf_size=30,
           metric=<function mydist at 0x7fd259c9cf50>, n_neighbors=4, p=2,
           weights='uniform')
    >>> nbrs.kneighbors(X)
    (array([[  0.,   1.,   5.,   8.],
           [  0.,   1.,   2.,  13.],
           [  0.,   2.,   5.,  25.],
           [  0.,   1.,   5.,   8.],
           [  0.,   1.,   2.,  13.],
           [  0.,   2.,   5.,  25.]]),
     array([[0, 1, 2, 3],
           [1, 0, 2, 3],
           [2, 1, 0, 3],
           [3, 4, 5, 0],
           [4, 3, 5, 0],
           [5, 4, 3, 0]]))
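
A note for readers on newer scikit-learn releases: as far as I can tell, the entries of metric_params are now unpacked and handed to the callable as ordinary keyword arguments, so the kwargs["metric_params"] lookup above would likely raise a KeyError there. A minimal sketch under that assumption (the explicit power parameter and its default are illustrative, not part of the original answer):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def mydist(x, y, power=2):
    # Assumption: scikit-learn calls mydist(x, y, **metric_params),
    # so "power" arrives as a plain keyword argument.
    return np.sum((x - y) ** power)


X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
Y = np.array([-1, -1, -2, 1, 1, 2])

clf = KNeighborsClassifier(
    n_neighbors=4,
    algorithm="ball_tree",
    metric=mydist,
    metric_params={"power": 2},
)
clf.fit(X, Y)

# For each training point: distances to, and indices of, its 4 nearest neighbors.
distances, indices = clf.kneighbors(X)
print(distances)
print(indices)
```

If the assumption holds, the printed arrays should match the ones shown in the transcript above.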
    
  • 2020-12-01 06:16

    You pass the metric as the metric parameter, and any additional metric arguments as keyword parameters to the NearestNeighbors constructor:

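    >>> import numpy as np
    >>> from sklearn.neighbors import NearestNeighbors
    >>> def mydist(x, y):
    ...     # squared Euclidean distance between two 1-D points
    ...     return np.sum((x-y)**2)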
    ...
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    
    >>> nbrs = NearestNeighbors(n_neighbors=4, algorithm='ball_tree',
    ...            metric='pyfunc', func=mydist)
    >>> nbrs.fit(X)
    NearestNeighbors(algorithm='ball_tree', leaf_size=30, metric='pyfunc',
             n_neighbors=4, radius=1.0)
    >>> nbrs.kneighbors(X)
    (array([[  0.,   1.,   5.,   8.],
           [  0.,   1.,   2.,  13.],
           [  0.,   2.,   5.,  25.],
           [  0.,   1.,   5.,   8.],
           [  0.,   1.,   2.,  13.],
           [  0.,   2.,   5.,  25.]]),
     array([[0, 1, 2, 3],
           [1, 0, 2, 3],
           [2, 1, 0, 3],
           [3, 4, 5, 0],
           [4, 3, 5, 0],
           [5, 4, 3, 0]]))
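
For the case in the question (the top k neighbors of a single point P), here is a rough sketch. It assumes a recent scikit-learn, where to my knowledge the callable is passed directly as metric instead of through metric='pyfunc', func=...; the query point P and k=3 are made-up values for illustration:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def mydist(x, y):
    # Squared Euclidean distance between two 1-D points.
    return np.sum((x - y) ** 2)


X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])

# Assumption: a callable can be given directly as the metric.
nbrs = NearestNeighbors(n_neighbors=3, algorithm="ball_tree", metric=mydist)
nbrs.fit(X)

# Hypothetical query point; kneighbors expects a 2-D array of shape (n_queries, n_features).
P = np.array([0.5, 0.5])
distances, indices = nbrs.kneighbors(P.reshape(1, -1))
print(distances)
print(indices)
```

One caveat: the squared distance does not satisfy the triangle inequality, which tree-based searches generally rely on. On a dataset this small it should make no difference, but algorithm='brute' may be the safer choice for such functions.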
    