Proximity Matrix in sklearn.ensemble.RandomForestClassifier

盖世英雄少女心 2020-12-28 20:01

I'm trying to perform clustering in Python using Random Forests. In the R implementation of Random Forests, there is a flag you can set to get the proximity matrix. I can'

3 Answers
  • 2020-12-28 20:52

    We don't implement a proximity matrix in scikit-learn (yet).

    However, this could be done by relying on the apply function provided in our implementation of decision trees. That is, for all pairs of samples in your dataset, iterate over the decision trees in the forest (through forest.estimators_) and count the number of times they fall in the same leaf, i.e., the number of times apply gives the same node id for both samples in the pair.
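
The counting procedure described above could be sketched roughly as follows. This is only an illustration: the helper name proximity_counts is hypothetical, and the iris data and n_estimators=25 are arbitrary choices made for the example, not anything from the original answer.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def proximity_counts(forest, X):
    """Count, for every pair of samples, the trees that put both in the same leaf.

    Hypothetical helper: it walks forest.estimators_ and calls apply on each tree.
    """
    n_samples = X.shape[0]
    counts = np.zeros((n_samples, n_samples), dtype=np.int64)
    for tree in forest.estimators_:
        leaves = tree.apply(X)  # leaf node id of each sample in this tree
        # Boolean (n_samples, n_samples) matrix: True where a pair shares a leaf
        counts += leaves[:, None] == leaves[None, :]
    return counts


# Illustrative data and settings (assumptions for this sketch only)
X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=25, random_state=0).fit(X, y)
counts = proximity_counts(forest, X)
```

Dividing counts by len(forest.estimators_) would turn the raw counts into the fraction of trees in which two samples share a leaf.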

    Hope this helps.

  • 2020-12-28 20:54

    Based on Gilles Louppe's answer, I have written a function. I don't know if it is efficient, but it works. Best regards.

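    import numpy as np

    def proximityMatrix(model, X, normalize=True):
        # Leaf index of every sample in every tree: shape (n_samples, n_trees)
        terminals = model.apply(X)
        nTrees = terminals.shape[1]

        # First tree: 1 where two samples share a leaf, 0 otherwise
        a = terminals[:, 0]
        proxMat = 1 * np.equal.outer(a, a)

        # Accumulate the same-leaf counts over the remaining trees
        for i in range(1, nTrees):
            a = terminals[:, i]
            proxMat += 1 * np.equal.outer(a, a)

        # Optionally convert counts to the fraction of trees sharing a leaf
        if normalize:
            proxMat = proxMat / nTrees

        return proxMat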
    
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_breast_cancer
    train = load_breast_cancer()
    
    model = RandomForestClassifier(n_estimators=500, max_features=2, min_samples_leaf=40)
    model.fit(train.data, train.target)
    proximityMatrix(model, train.data, normalize=True)
    ## array([[ 1.   ,  0.414,  0.77 , ...,  0.146,  0.79 ,  0.002],
    ##        [ 0.414,  1.   ,  0.362, ...,  0.334,  0.296,  0.008],
    ##        [ 0.77 ,  0.362,  1.   , ...,  0.218,  0.856,  0.   ],
    ##        ..., 
    ##        [ 0.146,  0.334,  0.218, ...,  1.   ,  0.21 ,  0.028],
    ##        [ 0.79 ,  0.296,  0.856, ...,  0.21 ,  1.   ,  0.   ],
    ##        [ 0.002,  0.008,  0.   , ...,  0.028,  0.   ,  1.   ]])
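
Since the question is about clustering, one possible next step is to turn the proximities into dissimilarities (1 - proximity) and hand them to a hierarchical clustering routine. The sketch below is only one way to do this and rests on assumptions that are not in the thread: iris as example data, 50 trees, average linkage, and a cut into at most 3 clusters.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Illustrative data and forest settings (assumptions for this sketch only)
X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# Leaf index of every sample in every tree: shape (n_samples, n_trees)
leaves = model.apply(X)

# Fraction of trees in which each pair of samples lands in the same leaf
proximity = (leaves[:, None, :] == leaves[None, :, :]).mean(axis=2)

# Proximity 1 means "always together", so use 1 - proximity as a distance
distance = 1.0 - proximity
np.fill_diagonal(distance, 0.0)

# SciPy's linkage expects the condensed (upper-triangle) form of the matrix
condensed = squareform(distance, checks=False)
tree = linkage(condensed, method="average")

# Cut the dendrogram into at most 3 flat clusters (labels start at 1)
labels = fcluster(tree, t=3, criterion="maxclust")
```

Note that this forest was fit with class labels, so the proximities reflect the supervised splits; other linkage methods or cluster counts may suit other data better.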
    
  • 2020-12-28 21:02

    There is nothing currently implemented for this in Python. I took a first try at it here. It would be great if somebody were interested in adding these methods to scikit-learn.
