I want to find the distance of samples to the decision boundary of a trained decision tree classifier in scikit-learn. The features are all numeric and the feature space is continuous.
Since there can be multiple decision boundaries around a sample, I'm going to assume distance here refers to the distance to the nearest decision boundary.
The solution is a recursive tree traversal algorithm. Note that a decision tree, unlike e.g. an SVM, does not allow a sample to lie on a boundary: every point in feature space must belong to one of the classes. So we will keep modifying the sample's features in small steps, and whenever that leads to a region with a different label (than the one originally assigned to the sample by the trained classifier), we assume we've reached the decision boundary.
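For context, this works because scikit-learn exposes the fitted tree structure through the tree_ attribute as parallel arrays indexed by node. A minimal sketch to inspect them (the iris dataset here is just for illustration):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
dt = DecisionTreeClassifier(max_depth=2).fit(X, y)
t = dt.tree_
print(t.feature)         # feature tested at each node (-2 for leaves)
print(t.threshold)       # split threshold at each node
print(t.children_left)   # index of left child (-1 for leaves)
print(t.children_right)  # index of right child (-1 for leaves)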
In detail, like any recursive algorithm, we have two main cases to consider:

1. The current node is a leaf. The candidate sample either gets a different label than the original one, in which case we've crossed a boundary and keep it, or it gets the same label and we discard it.
2. The current node is internal. We recurse into the child the sample falls into, leaving the sample unchanged, and also into the other child, with a copy of the sample whose split feature is nudged just past the node's threshold.
Complete Python code:
def f(node, x, orig_label):
    global dt, tree
    # Base case: scikit-learn marks leaves with children_left == children_right
    if tree.children_left[node] == tree.children_right[node]:
        # Keep the candidate only if it lands in a region with a different label
        return [x] if dt.predict([x])[0] != orig_label else [None]
    # Recursive case: follow the side the sample falls into unchanged,
    # and probe the other side with the split feature pushed just across the threshold
    if x[tree.feature[node]] <= tree.threshold[node]:
        orig = f(tree.children_left[node], x, orig_label)
        xc = x.copy()
        xc[tree.feature[node]] = tree.threshold[node] + .01
        modif = f(tree.children_right[node], xc, orig_label)
    else:
        orig = f(tree.children_right[node], x, orig_label)
        xc = x.copy()
        # Setting the feature exactly to the threshold sends the copy left (splits are <=)
        xc[tree.feature[node]] = tree.threshold[node]
        modif = f(tree.children_left[node], xc, orig_label)
    return [s for s in orig + modif if s is not None]
This returns a list of candidate samples that land in leaves with a different label. All we need to do now is take the nearest one:
import numpy as np
from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(max_depth=2).fit(X, y)
tree = dt.tree_
res = f(0, x, dt.predict([x])[0])  # 0 is the index of the root node
ans = np.min([np.linalg.norm(x - n) for n in res])
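Putting it all together, here is a self-contained end-to-end sketch; the make_blobs dataset and the choice of X[0] as the query sample are just for illustration:

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier

# Synthetic two-class data with continuous features
X, y = make_blobs(n_samples=100, centers=2, random_state=0)
dt = DecisionTreeClassifier(max_depth=2).fit(X, y)
tree = dt.tree_

x = X[0]
res = f(0, x, dt.predict([x])[0])
print(np.min([np.linalg.norm(x - n) for n in res]))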
For illustration:
Blue is the original sample, yellow is the nearest sample "on" the decision boundary.
A decision tree does not learn to draw a decision boundary. It splits the data at the point of maximum information gain, and for this the algorithm uses the entropy or Gini index.
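For reference, a minimal sketch of those two impurity measures (the helper names are just for illustration):

import numpy as np

def gini(labels):
    # Gini impurity: 1 - sum over classes of p_k^2
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1 - np.sum(p ** 2)

def entropy(labels):
    # Shannon entropy: -sum over classes of p_k * log2(p_k)
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

print(gini([0, 0, 1, 1]))     # 0.5: maximally impure two-class node
print(entropy([0, 0, 1, 1]))  # 1.0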
Because the tree only makes such impurity-driven splits, you cannot find the distance between the points and the decision boundary (there is no explicit decision boundary).
If you want, you can calculate the distance between the points and the split lines drawn on the graph; that gives an approximate result.
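Since the splits are axis-aligned, the distance from a point to one split line is just the absolute difference along the split feature. A rough sketch of that idea, assuming x is a NumPy array; the helper name dist_to_splits is hypothetical, and this ignores that each split is only a boundary inside its node's region, so it is only an approximation:

import numpy as np

def dist_to_splits(dt, x):
    # Distance from x to every split hyperplane in the fitted tree dt
    t = dt.tree_
    internal = t.feature >= 0  # leaves are marked with feature == -2
    return np.abs(x[t.feature[internal]] - t.threshold[internal])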