Extracting decision rules from GradientBoostingClassifier

逝去的感伤 2021-01-21 13:19

I have gone through the questions below:

how to extract decision rules of GradientBoostingClassifier

How to extract the decision rules from scikit-learn decision

1 answer
  • 2021-01-21 13:53

    There is no need to use the graphviz export to access the decision tree data. model.estimators_ contains all the individual trees that make up the model. For a GradientBoostingClassifier, this is a 2D numpy array whose rows are boosting stages: its shape is (n_estimators, n_classes) for a multiclass problem and (n_estimators, 1) for a binary one. Each item is a DecisionTreeRegressor.
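    To make the layout of estimators_ concrete, here is a minimal sketch that fits one multiclass and one binary model on random data and compares the shapes. The variable names (multi, binary, rng) and the random data are illustrative assumptions, not part of the original answer.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.RandomState(1)
X = rng.random_sample((100, 3))

# Three classes: one regression tree per class at each boosting stage.
multi = GradientBoostingClassifier(n_estimators=4, random_state=0)
multi.fit(X, rng.choice([0, 1, 2], size=100))
multi_shape = multi.estimators_.shape

# Two classes: a single regression tree per boosting stage.
binary = GradientBoostingClassifier(n_estimators=4, random_state=0)
binary.fit(X, rng.choice([0, 1], size=100))
binary_shape = binary.estimators_.shape

print(multi_shape, type(multi.estimators_[0, 0]).__name__)
print(binary_shape)
```

    In both cases the first index is the boosting stage, so estimators_[t, c] is the tree for stage t and class column c.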

    Each decision tree has a tree_ attribute, and the scikit-learn example "Understanding the decision tree structure" shows how to read the nodes, thresholds and children from that object.

    
    import numpy
    import pandas
    from sklearn.ensemble import GradientBoostingClassifier
    
    # Fit a small 3-class model on random data so there are trees to inspect.
    est = GradientBoostingClassifier(n_estimators=4)
    numpy.random.seed(1)
    est.fit(numpy.random.random((100, 3)), numpy.random.choice([0, 1, 2], size=(100,)))
    print('s', est.estimators_.shape)
    
    # estimators_ is indexed as [boosting stage, class], i.e. (n_estimators, n_classes).
    n_estimators, n_classes = est.estimators_.shape
    for c in range(n_classes):
        for t in range(n_estimators):
            dtree = est.estimators_[t, c]
            print("class={}, tree={}: {}".format(c, t, dtree.tree_))
    
            # One row per node; the row index is the node id.
            rules = pandas.DataFrame({
                'child_left': dtree.tree_.children_left,
                'child_right': dtree.tree_.children_right,
                'feature': dtree.tree_.feature,
                'threshold': dtree.tree_.threshold,
            })
            print(rules)
    

    This prints something like the following for each tree. Rows where both children are -1 (and feature and threshold are -2) are leaves:

    class=0, tree=0: <sklearn.tree._tree.Tree object at 0x7f18a697f370>
       child_left  child_right  feature  threshold
    0           1            2        0   0.020702
    1          -1           -1       -2  -2.000000
    2           3            6        1   0.879058
    3           4            5        1   0.543716
    4          -1           -1       -2  -2.000000
    5          -1           -1       -2  -2.000000
    6           7            8        0   0.292586
    7          -1           -1       -2  -2.000000
    8          -1           -1       -2  -2.000000
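    The table above can be turned into readable if/then rules by walking the tree from the root and collecting the split conditions on the way to each leaf. The following is a rough sketch of that idea, not an official scikit-learn API: tree_to_rules and the x[i] feature naming are hypothetical, and the leaf value is read from tree_.value, which for these regression trees is assumed to hold the tree's raw contribution before the learning rate is applied.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier


def tree_to_rules(dtree, feature_names=None):
    """Return one (conditions, leaf_value) pair per leaf of a fitted tree.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    t = dtree.tree_
    rules = []

    def walk(node, conditions):
        left, right = t.children_left[node], t.children_right[node]
        if left == right:  # both children are -1 at a leaf
            rules.append((conditions, float(t.value[node, 0, 0])))
            return
        feat = int(t.feature[node])
        name = feature_names[feat] if feature_names is not None else "x[{}]".format(feat)
        thr = float(t.threshold[node])
        # Samples with feature <= threshold go left, the rest go right.
        walk(left, conditions + ["{} <= {:.6f}".format(name, thr)])
        walk(right, conditions + ["{} > {:.6f}".format(name, thr)])

    walk(0, [])
    return rules


rng = np.random.RandomState(1)
X = rng.random_sample((100, 3))
y = rng.choice([0, 1, 2], size=100)
est = GradientBoostingClassifier(n_estimators=4, random_state=0).fit(X, y)

dtree = est.estimators_[0, 0]  # boosting stage 0, class 0
rules = tree_to_rules(dtree)
for conditions, value in rules:
    print("if {}: value = {:.4f}".format(" and ".join(conditions) or "True", value))
```

    Each printed line corresponds to one leaf row of the table. Running the helper over every entry of est.estimators_ gives the rules for the whole ensemble, one tree at a time.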
    