Getting decision path to a node in sklearn

前端 未结 2 755
轻奢々
轻奢々 2020-12-31 23:04

I want the decision path (i.e., the set of rules) from the root node to a given node (which I supply) in a decision tree (DecisionTreeClassifier) in scikit-learn.

相关标签:
2条回答
  • 2020-12-31 23:30

    If you pass out_file=None to export_graphviz, it returns a string representation of the tree instead of writing a file.

    from sklearn.datasets import load_iris
    from sklearn import tree
    
    # Fit a default decision tree classifier on the complete iris dataset.
    iris = load_iris()
    clf = tree.DecisionTreeClassifier()
    clf = clf.fit(iris.data, iris.target)
    
    # out_file=None makes export_graphviz return the DOT source as a string
    # rather than writing it to a file.
    # feature_names / class_names are required for the node labels to read
    # "petal length (cm) <= ..." and "class = setosa" as in the output below;
    # without them the labels use generic indexed feature names and carry no
    # class line.
    string_data = tree.export_graphviz(clf,
        out_file=None,
        feature_names=iris.feature_names,
        class_names=iris.target_names)
    
    print(string_data)
    
    #Output
    digraph Tree {
    node [shape=box] ;
    0 [label="petal length (cm) <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]\nclass = setosa"] ;
    1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]\nclass = setosa"] ;
    0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
    2 [label="petal width (cm) <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]\nclass = versicolor"] ;
    0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
    3 [label="petal length (cm) <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]\nclass = versicolor"] ;
    2 -> 3 ;
    4 [label="petal width (cm) <= 1.65\ngini = 0.041\nsamples = 48\nvalue = [0, 47, 1]\nclass = versicolor"] ;
    3 -> 4 ;
    5 [label="gini = 0.0\nsamples = 47\nvalue = [0, 47, 0]\nclass = versicolor"] ;
    4 -> 5 ;
    6 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica"] ;
    4 -> 6 ;
    7 [label="petal width (cm) <= 1.55\ngini = 0.444\nsamples = 6\nvalue = [0, 2, 4]\nclass = virginica"] ;
    3 -> 7 ;
    8 [label="gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]\nclass = virginica"] ;
    7 -> 8 ;
    9 [label="sepal length (cm) <= 6.95\ngini = 0.444\nsamples = 3\nvalue = [0, 2, 1]\nclass = versicolor"] ;
    7 -> 9 ;
    10 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = versicolor"] ;
    9 -> 10 ;
    11 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = virginica"] ;
    9 -> 11 ;
    12 [label="petal length (cm) <= 4.85\ngini = 0.043\nsamples = 46\nvalue = [0, 1, 45]\nclass = virginica"] ;
    2 -> 12 ;
    13 [label="sepal length (cm) <= 5.95\ngini = 0.444\nsamples = 3\nvalue = [0, 1, 2]\nclass = virginica"] ;
    12 -> 13 ;
    14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = versicolor"] ;
    13 -> 14 ;
    15 [label="gini = 0.0\nsamples = 2\nvalue = [0, 0, 2]\nclass = virginica"] ;
    13 -> 15 ;
    16 [label="gini = 0.0\nsamples = 43\nvalue = [0, 0, 43]\nclass = virginica"] ;
    12 -> 16 ;
    }
    

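    This string contains the rule at every node and the edges between nodes, which is what you need. You can then write a small parser that follows the edges from the root (node 0) to the node you are interested in and collects the rules along the way.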

    0 讨论(0)
  • 2020-12-31 23:39

    For the decision rules of the nodes using the iris dataset:

    from sklearn.datasets import load_iris
    from sklearn import tree
    import graphviz
    
    # Fit a default decision tree classifier on the complete iris dataset.
    iris = load_iris()
    clf = tree.DecisionTreeClassifier().fit(iris.data, iris.target)
    
    # Export the fitted tree as DOT source. out_file=None returns the source as
    # a string; the feature and class names label each node's rule readably.
    export_options = dict(
        out_file=None,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )
    dot_data = tree.export_graphviz(clf, **export_options)
    
    # Rendering creates an iris.pdf file showing the rule path.
    graph = graphviz.Source(dot_data)
    graph.render("iris")
    


    For sample-based paths (the nodes that a particular sample passes through on its way to a leaf), use:

    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    
    # Train a tree limited to three leaves on a fixed train/test split of the
    # iris data (random_state=0 for both the split and the estimator).
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    
    estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
    estimator.fit(X_train, y_train)
    
    # The fitted estimator exposes the whole tree through its `tree_` attribute,
    # which stores the binary tree as several parallel arrays: element i of each
    # array describes node i, and node 0 is the root. NOTE: some arrays are only
    # meaningful for split nodes or only for leaves; entries for the other kind
    # of node are arbitrary.
    #
    # The arrays used here are:
    #   - children_left, id of the left child of each node
    #   - children_right, id of the right child of each node
    #   - feature, feature used for splitting each node
    #   - threshold, threshold value at each node
    fitted_tree = estimator.tree_
    n_nodes = fitted_tree.node_count
    children_left = fitted_tree.children_left
    children_right = fitted_tree.children_right
    feature = fitted_tree.feature
    threshold = fitted_tree.threshold
    
    # Walk the tree with an explicit stack to record, for every node, its depth
    # and whether it is a leaf.
    node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]  # start at the root; -1 stands in for its parent's depth
    while stack:
        node_id, parent_depth = stack.pop()
        depth = parent_depth + 1
        node_depth[node_id] = depth
    
        left, right = children_left[node_id], children_right[node_id]
        if left == right:
            # Both child ids are equal, so this node is not a test node.
            is_leaves[node_id] = True
            continue
    
        # Test node: schedule both children, passing this node's depth along.
        stack.append((left, depth))
        stack.append((right, depth))
    
    print("The binary tree structure has %s nodes and has "
          "the following tree structure:"
          % n_nodes)
    for i in range(n_nodes):
        indent = node_depth[i] * "\t"  # one tab per level of depth
        if is_leaves[i]:
            print("%snode=%s leaf node." % (indent, i))
            continue
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (indent,
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
    print()
    
    # decision_path returns a node indicator matrix: a non-zero element at
    # position (i, j) means that sample i passes through node j.
    node_indicator = estimator.decision_path(X_test)
    
    # apply returns, for each sample, the id of the leaf it ends up in.
    leave_id = estimator.apply(X_test)
    
    # THIS IS THE PART THAT ANSWERS THE QUESTION: list the tests applied to one
    # sample. The node ids visited by `sample_id` are the slice of `indices`
    # delimited by two consecutive `indptr` entries.
    sample_id = 0
    row_start = node_indicator.indptr[sample_id]
    row_end = node_indicator.indptr[sample_id + 1]
    node_index = node_indicator.indices[row_start:row_end]
    
    print('Rules used to predict sample %s: ' % sample_id)
    for node_id in node_index:
        # The sample's leaf has no test to report, only a notice.
        if leave_id[sample_id] == node_id:
            print("leaf node {} reached, no decision here".format(leave_id[sample_id]))
            continue
    
        # Decision node: report which side of the threshold the sample fell on.
        feature_value = X_test[sample_id, feature[node_id]]
        threshold_sign = "<=" if feature_value <= threshold[node_id] else ">"
    
        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 feature_value,
                 threshold_sign,
                 threshold[node_id]))
    

    At the end, this prints the following:

    Rules used to predict sample 0: 
    decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011920929)
    decision id node 2 : (X[0, 2] (= 5.1) > 4.949999809265137)
    leaf node 4 reached, no decision here


    0 讨论(0)
提交回复
热议问题