Question
Expanding on a prior question: Changing colors for decision tree plot created using export graphviz
How would I color the nodes of the tree based on the dominant class (species of iris) instead of a binary distinction? This should require a combination of iris.target_names, the strings describing the classes, and iris.target, the class labels.
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections
clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()
colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)
for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))
for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])
graph.write_png('tree.png')
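A minimal sketch of the combination described in the question, assuming a node's class distribution is available as a list of per-class counts in the same order as `iris.target_names` (for iris: 'setosa', 'versicolor', 'virginica'). The helpers `dominant_class` and `fill_color` and the `CLASS_COLORS` palette are hypothetical names for illustration: the index of the largest count plays the role of `iris.target`, and `target_names` turns it into the species string.

```python
# Hypothetical helpers: map a node's per-class counts to its dominant class.
TARGET_NAMES = ['setosa', 'versicolor', 'virginica']  # same order as iris.target_names

# Assumed palette: one Graphviz color name per species.
CLASS_COLORS = {'setosa': 'lightblue',
                'versicolor': 'lightyellow',
                'virginica': 'forestgreen'}


def dominant_class(values, target_names=TARGET_NAMES):
    """Return the name of the class with the most samples in a node."""
    # index of the largest count = the class label (as in iris.target);
    # on ties, max() keeps the first index, like numpy.argmax
    label = max(range(len(values)), key=lambda i: values[i])
    return target_names[label]


def fill_color(values):
    """Pick a fill color for a node from its dominant class."""
    return CLASS_COLORS[dominant_class(values)]


print(dominant_class([0, 49, 5]))  # versicolor
print(fill_color([0, 1, 45]))      # forestgreen
```

Inside a loop over the Graphviz nodes, `values` could come from `clf.tree_.value[int(node.get_name())][0]`, and `node.set_fillcolor(fill_color(values))` would then color each node by its majority species.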
Answer 1:
The code from the example looks so familiar and is therefore easy to modify :)
For each node, Graphviz tells us how many samples from each class it contains, i.e. whether it is a mixed population or the tree has reached a decision. We can extract this information and use it to pick a color.
values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]  # float() also accepts integer counts
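To see what that one-liner extracts, here it is wrapped in a hypothetical helper `parse_value` and applied to a sample label. The label text is an assumption modeled on what `export_graphviz(..., special_characters=True)` produces; the exact text can differ between scikit-learn versions (some may print fractions rather than counts), which is why `float` is used rather than `int`.

```python
def parse_value(label):
    """Extract the list after 'value = [' from a Graphviz node label."""
    inner = label.split('value = [')[1].split(']')[0]
    # float() accepts both integer counts ('50') and fractions ('0.333'),
    # and ignores the spaces left over from splitting on ','
    return [float(ii) for ii in inner.split(',')]


# Assumed example label for the root node of the iris tree
label = ('<petal length (cm) &le; 2.45<br/>gini = 0.667<br/>'
         'samples = 150<br/>value = [50, 50, 50]>')
print(parse_value(label))  # [50.0, 50.0, 50.0]
```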
Alternatively, you can map the GraphViz nodes back to the sklearn nodes:
values = clf.tree_.value[int(node.get_name())][0]
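This works because `export_graphviz` names each Graphviz node after its index in the fitted tree, so `int(node.get_name())` indexes `clf.tree_.value`, whose shape is `(node_count, n_outputs, n_classes)`; the trailing `[0]` picks the single output. The sketch below mimics that lookup with a plain nested list standing in for `clf.tree_.value` (the numbers are made up) and skips the `node`/`edge` default-attribute entries that pydotplus also returns as nodes. `values_for_node` is a hypothetical helper.

```python
# Stand-in for clf.tree_.value: shape (node_count, n_outputs, n_classes).
# Made-up numbers for a 3-node tree with one output and 3 classes.
tree_value = [
    [[50, 50, 50]],  # node 0 (root)
    [[50, 0, 0]],    # node 1
    [[0, 50, 50]],   # node 2
]


def values_for_node(name, tree_value):
    """Return the class distribution of a Graphviz node, or None for
    the 'node'/'edge' default-attribute statements."""
    if name in ('node', 'edge'):
        return None
    # Graphviz node names are the tree's node indices; [0] = first output
    return tree_value[int(name)][0]


print(values_for_node('1', tree_value))     # [50, 0, 0]
print(values_for_node('edge', tree_value))  # None
```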
We only have 3 classes, so each one gets its own color (red, green, blue); mixed populations get mixed colors according to their class distribution.
values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
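As a worked check of the formula, each class's share of the node is scaled to 0 to 255 and written as one hex channel: a pure setosa node (class 0) becomes pure red, and a root with an even 50/50/50 split becomes a dark gray. The wrapper name `distribution_to_hex` is hypothetical, and the mapping only makes sense for exactly three classes.

```python
def distribution_to_hex(values):
    """Map a 3-class distribution to an RGB hex color
    (class 0 -> red, class 1 -> green, class 2 -> blue)."""
    if len(values) != 3:
        raise ValueError('expected exactly 3 classes')
    total = sum(values)
    # each channel is the class's share of the node, scaled to 0..255
    channels = [int(255 * v / total) for v in values]
    return '#{:02x}{:02x}{:02x}'.format(*channels)


print(distribution_to_hex([50, 0, 0]))    # #ff0000
print(distribution_to_hex([50, 50, 50]))  # #555555
print(distribution_to_hex([0, 49, 5]))    # #00e717
```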
We can now see the separation nicely: the greener a node, the more samples of the 2nd class it contains, and likewise for blue and the 3rd class.
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
for node in nodes:
    # the default 'node' and 'edge' statements have no label, skip them
    if node.get_label():
        # parse the class distribution, e.g. 'value = [50, 50, 50]'
        values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        # scale each class's share to 0..255 and use it as the R, G, B channel
        values = [int(255 * v / sum(values)) for v in values]
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
        node.set_fillcolor(color)
graph.write_png('colored_tree.png')
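The RGB trick above is tied to exactly three classes. One possible generalization, which is a different technique from the one in this answer and only a sketch, is to give every class its own RGB color and blend those colors weighted by the node's class proportions. The palette and the name `blend_color` are assumptions for illustration; the returned hex string could be passed to `node.set_fillcolor` in the loop above.

```python
# Assumed palette: one RGB triple per class (extend it for more classes).
PALETTE = [
    (228, 26, 28),   # class 0
    (55, 126, 184),  # class 1
    (77, 175, 74),   # class 2
    (152, 78, 163),  # class 3
]


def blend_color(values, palette=PALETTE):
    """Average the class colors, weighted by each class's share of the node."""
    if len(values) > len(palette):
        raise ValueError('palette has fewer colors than there are classes')
    total = sum(values)
    rgb = [0.0, 0.0, 0.0]
    for v, color in zip(values, palette):
        for ch in range(3):
            rgb[ch] += color[ch] * v / total
    return '#{:02x}{:02x}{:02x}'.format(*(int(round(c)) for c in rgb))


print(blend_color([0, 10, 0, 0]))  # #377eb8 (pure class-1 node)
```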
A general solution for more than 3 classes, which colors only the final (pure) nodes, i.e. nodes containing samples of a single class:
import numpy
# one color per class; the last entry is the default for mixed nodes
colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')
for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        values = clf.tree_.value[int(node.get_name())][0]
        # color only nodes where only one class is present
        if max(values) == sum(values):
            node.set_fillcolor(colors[numpy.argmax(values)])
        # mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])
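The coloring rule in that loop can be checked in isolation. Below, a hypothetical `pure_node_color` function takes a node's class distribution and a palette whose last entry is the default, and returns the class color only when every sample belongs to one class; the distributions used are made-up examples.

```python
COLORS = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')


def pure_node_color(values, colors=COLORS):
    """Class color for a pure node, the last (default) color otherwise."""
    values = list(values)
    if max(values) == sum(values):
        # all samples belong to the class with the largest count
        return colors[values.index(max(values))]
    return colors[-1]


print(pure_node_color([0, 0, 46]))  # forestgreen
print(pure_node_color([0, 49, 5]))  # white
```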
Answer 2:
Great answers, guys. Just to add to @Maximilian Peters's answer: another way to identify leaf nodes for specific coloring is to check the split criterion (threshold) values. Leaf nodes have no child nodes and hence no split criterion, so sklearn stores TREE_UNDEFINED (-2) as their threshold, as defined in the tree source:
https://github.com/scikit-learn/scikit-learn/blob/a24c8b464d094d2c468a16ea9f8bf8d42d949f84/sklearn/tree/_tree.pyx
import numpy
# threshold value sklearn stores for leaf nodes (see _tree.pyx)
TREE_UNDEFINED = -2
thresholds = clf.tree_.threshold
for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        values = clf.tree_.value[int(node.get_name())][0]
        # color only nodes where only one class is present or if it is a leaf node
        if (max(values) == sum(values) or
                thresholds[int(node.get_name())] == TREE_UNDEFINED):
            node.set_fillcolor(colors[numpy.argmax(values)])
        # mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])
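A small sketch of the threshold check on its own, with a plain list standing in for `clf.tree_.threshold` (the values describe a made-up 5-node tree). `leaf_indices` is a hypothetical helper; it relies on sklearn storing TREE_UNDEFINED (-2) as the threshold of leaf nodes, as in the source linked above.

```python
TREE_UNDEFINED = -2  # threshold value sklearn stores for leaf nodes


def leaf_indices(thresholds):
    """Return the indices of nodes that have no split, i.e. the leaves."""
    return [i for i, t in enumerate(thresholds) if t == TREE_UNDEFINED]


# Stand-in for clf.tree_.threshold: nodes 0 and 2 split, the rest are leaves
thresholds = [2.45, -2.0, 1.75, -2.0, -2.0]
print(leaf_indices(thresholds))  # [1, 3, 4]
```

An equivalent check should be `clf.tree_.children_left[i] == -1`, since leaves have no children.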
Not completely related to the question, but adding some more info in case it is helpful to others: continuing this idea of understanding the decision stumps of a tree-based classifier, Skater has added support for summarizing all forms of tree-based models using tree surrogates. Check out the examples here:
https://github.com/datascienceinc/Skater/blob/master/examples/rule_list_notebooks/explanation_using_tree_surrogate.ipynb
Source: https://stackoverflow.com/questions/43214350/color-of-the-node-of-tree-with-graphviz-using-class-names