Retrieve Decision Boundary Lines (x,y coordinate format) from SKlearn Decision Tree

Asked by 青春惊慌失措 on 2021-01-13 08:35

I am trying to create a surface plot on an external visualization platform. I'm working with the iris data set that is featured on the sklearn decision tree documentation page, and I'd like to retrieve the decision boundary lines from the fitted tree in (x, y) coordinate format.

3 Answers
  •  走了就别回头了
    2021-01-13 09:22

    Decision trees do not have very nice boundaries. They have multiple boundaries that hierarchically split the feature space into rectangular regions.
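    As a quick orientation (this sketch is my addition, not part of the original code), the fitted classifier exposes the raw split structure through its tree_ attribute; the parsing functions at the end of this answer build on exactly these arrays:

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    clf = DecisionTreeClassifier(max_depth=2).fit(*load_iris(return_X_y=True))
    t = clf.tree_
    for i in range(t.node_count):
        if t.children_left[i] == -1:  # -1 is sklearn.tree._tree.TREE_LEAF
            print('node %d: leaf, majority class %d' % (i, t.value[i].argmax()))
        else:
            print('node %d: split feature %d at %.2f' % (i, t.feature[i], t.threshold[i]))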

    In my implementation of Node Harvest I wrote functions that parse scikit's decision trees and extract the decision regions. For this answer I modified parts of that code to return a list of rectangles that correspond to a tree's decision regions. It should be easy to draw these rectangles with any plotting library. Here is an example using matplotlib:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.tree import DecisionTreeClassifier

    # decision_areas and plot_areas are defined in the implementation at the end
    n = 100
    np.random.seed(42)
    x = np.concatenate([np.random.randn(n, 2) + 1, np.random.randn(n, 2) - 1])
    y = ['b'] * n + ['r'] * n
    plt.scatter(x[:, 0], x[:, 1], c=y)

    dtc = DecisionTreeClassifier().fit(x, y)
    rectangles = decision_areas(dtc, [-3, 3, -3, 3])
    plot_areas(rectangles)
    plt.xlim(-3, 3)
    plt.ylim(-3, 3)
    

    Wherever regions of different color meet there is a decision boundary. It is possible with moderate effort to extract just these boundary lines; a sketch of one way to do it follows the description of the rectangles array below.

    rectangles is a numpy array. Each row corresponds to one rectangle and the columns are [left, right, bottom, top, class] (i.e. [xmin, xmax, ymin, ymax, class]).
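
    For anyone who wants the actual boundary lines in (x, y) coordinates, here is one possible sketch (my addition, not part of the original code). It assumes the [left, right, bottom, top, class] row format above: wherever two rectangles of different classes share an edge, the overlapping part of that edge is a boundary segment. This works cleanly when the rectangles tile the plane, i.e. when the tree was fit on the two plotted features:

    def boundary_segments(rectangles, tol=1e-9):
        """Return boundary line segments as ((x0, y0), (x1, y1)) tuples."""
        segments = []
        for i in range(len(rectangles)):
            for j in range(i + 1, len(rectangles)):
                a, b = rectangles[i], rectangles[j]
                if a[4] == b[4]:
                    continue  # same class: no decision boundary between them
                # vertical boundary: right edge of one meets left edge of the other
                for xr, xl in ((a[1], b[0]), (b[1], a[0])):
                    if abs(xr - xl) < tol:
                        lo, hi = max(a[2], b[2]), min(a[3], b[3])  # y overlap
                        if hi > lo:
                            segments.append(((xr, lo), (xr, hi)))
                # horizontal boundary: top edge of one meets bottom edge of the other
                for yt, yb in ((a[3], b[2]), (b[3], a[2])):
                    if abs(yt - yb) < tol:
                        lo, hi = max(a[0], b[0]), min(a[1], b[1])  # x overlap
                        if hi > lo:
                            segments.append(((lo, yt), (hi, yt)))
        return segments

    # e.g. draw the segments on top of the scatter plot:
    # for (x0, y0), (x1, y1) in boundary_segments(rectangles):
    #     plt.plot([x0, x1], [y0, y1], 'k-')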


    Update: Application to the Iris data set

    The Iris data set contains three classes instead of the two in the example above, so we have to add another color to the plot_areas function: color = ['b', 'r', 'g'][int(rect[4])]. Furthermore, the data set is 4-dimensional (it contains four features) but we can only plot two features in 2D. We need to choose which features to plot and tell the decision_areas function. It takes two keyword arguments x and y - these are the indices of the features that go on the x and y axis, respectively. The default is x=0, y=1, which works with any data set that has more than one feature. However, in the Iris data set the first dimension is not very interesting, so we will use a different setting.

    The function decision_areas also does not know about the extent of the data set. Often the decision tree has open decision ranges that extend toward infinity (e.g. whenever sepal length is less than xyz it's class B). In this case we need to artificially narrow down the range for plotting. I chose -3..3 for the example data set, but for the iris data set other ranges are appropriate (the features are never negative, and some extend beyond 3).
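
    A simple way to pick such a range (again my suggestion, not part of the original code) is to read the per-feature extent off the data itself:

    import numpy as np
    from sklearn.datasets import load_iris

    x = load_iris().data
    print(np.c_[x.min(axis=0), x.max(axis=0)])  # per-feature [min, max]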

    Here we plot the decision regions over the last two features in a range of 0..7 and 0..5:

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    import matplotlib.pyplot as plt

    data = load_iris()
    x = data.data
    y = data.target
    dtc = DecisionTreeClassifier().fit(x, y)
    rectangles = decision_areas(dtc, [0, 7, 0, 5], x=2, y=3)
    plt.scatter(x[:, 2], x[:, 3], c=y)
    plot_areas(rectangles)
    

    Note how there is a weird overlap of the red and green areas in the top left. This happens because the tree makes decisions in four dimensions but we can show only two. There is not really a clean way around this. A high dimensional classifier often has no nice decision boundaries in low-dimensional space.

    So if you are more interested in the classifier, that is what you get. You can generate different views along various combinations of dimensions, as sketched below, but there are limits to the usefulness of the representation.
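
    For example, here is a sketch (assuming the decision_areas/plot_areas helpers below, with the three-color tweak from above) that renders one view per pair of features, taking the plotting ranges from the data:

    import itertools
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    data = load_iris()
    x, y = data.data, data.target
    dtc = DecisionTreeClassifier().fit(x, y)

    fig, axes = plt.subplots(2, 3, figsize=(12, 7))
    for ax, (i, j) in zip(axes.flat, itertools.combinations(range(4), 2)):
        rng = [x[:, i].min(), x[:, i].max(), x[:, j].min(), x[:, j].max()]
        plt.sca(ax)  # plot_areas draws into the current axes
        ax.scatter(x[:, i], x[:, j], c=y, s=10)
        plot_areas(decision_areas(dtc, rng, x=i, y=j))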

    However, if you are more interested in the data than in the classifier you can restrict the dimensionality before fitting. In that case the classifier only makes decisions in the 2-dimensional space and we can plot nice decision regions:

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    import matplotlib.pyplot as plt

    data = load_iris()
    x = data.data[:, [2, 3]]
    y = data.target
    dtc = DecisionTreeClassifier().fit(x, y)
    rectangles = decision_areas(dtc, [0, 7, 0, 3], x=0, y=1)
    plt.scatter(x[:, 0], x[:, 1], c=y)
    plot_areas(rectangles)
    


    Finally, here is the implementation:

    import numpy as np
    from collections import deque
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.tree import _tree as ctree
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    
    
    class AABB:
        """Axis-aligned bounding box"""
        def __init__(self, n_features):
            self.limits = np.array([[-np.inf, np.inf]] * n_features)
    
        def split(self, f, v):
            """Split the box at value v along feature f; return (left, right)."""
            left = AABB(self.limits.shape[0])
            right = AABB(self.limits.shape[0])
            left.limits = self.limits.copy()
            right.limits = self.limits.copy()

            # children inherit the parent's limits, with one bound tightened
            left.limits[f, 1] = v
            right.limits[f, 0] = v

            return left, right
    
    
    def tree_bounds(tree, n_features=None):
        """Compute the axis-aligned bounding box of every node in the tree"""
        if n_features is None:
            n_features = np.max(tree.feature) + 1
        aabbs = [AABB(n_features) for _ in range(tree.node_count)]
        queue = deque([0])  # start at the root node
        while queue:
            i = queue.pop()
            l = tree.children_left[i]
            r = tree.children_right[i]
            if l != ctree.TREE_LEAF:
                # pass the node's box down, split at this node's threshold
                aabbs[l], aabbs[r] = aabbs[i].split(tree.feature[i], tree.threshold[i])
                queue.extend([l, r])
        return aabbs
    
    
    def decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
        """ Extract decision areas.
    
        tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
        maxrange: values [xmin, xmax, ymin, ymax] to substitute where an interval is open (+/-inf)
        x: index of the feature that goes on the x axis
        y: index of the feature that goes on the y axis
        n_features: override autodetection of number of features
        """
        tree = tree_classifier.tree_
        aabbs = tree_bounds(tree, n_features)
    
        rectangles = []
        for i in range(len(aabbs)):
            if tree.children_left[i] != ctree.TREE_LEAF:
                continue  # only leaves define decision regions
            l = aabbs[i].limits
            r = [l[x, 0], l[x, 1], l[y, 0], l[y, 1], np.argmax(tree.value[i])]
            rectangles.append(r)
        rectangles = np.array(rectangles)
        # clip open (infinite) intervals to the requested plotting range
        rectangles[:, [0, 2]] = np.maximum(rectangles[:, [0, 2]], maxrange[0::2])
        rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
        return rectangles
    
    def plot_areas(rectangles):
        for rect in rectangles:
            # extend this color list (e.g. add 'g') for more than two classes
            color = ['b', 'r'][int(rect[4])]
            rp = Rectangle([rect[0], rect[2]],   # (left, bottom) corner
                           rect[1] - rect[0],    # width
                           rect[3] - rect[2],    # height
                           color=color, alpha=0.3)
            plt.gca().add_artist(rp)
    
