Color of the node of tree with graphviz using class_names

Submitted by 喜你入骨 on 2019-12-20 02:15:38

Question


Expanding on a prior question: Changing colors for decision tree plot created using export graphviz

How would I color the nodes of the tree based on the dominant class (species of iris), instead of a binary distinction? This should require a combination of iris.target_names, the strings describing the classes, and iris.target, the class labels.

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()    
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

Answer 1:


The code from the example looks so familiar and is therefore easy to modify :)

For each node, Graphviz tells us how many samples from each group we have, i.e. whether it is a mixed population or the tree came to a decision. We can extract this info from the node label and use it to derive a color. (float() is used rather than int() so that labels containing decimal values also parse.)

values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]

Alternatively you can map the GraphViz nodes back to the sklearn nodes:

values = clf.tree_.value[int(node.get_name())][0]
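For the label-parsing route, a regular expression can be a little more forgiving than chained split() calls. The following is only a sketch: SAMPLE_LABEL is a hypothetical label shaped like what export_graphviz produces with special_characters=True, and parse_class_values is a helper name introduced here for illustration. Real labels may differ between scikit-learn versions.

```python
import re

# Hypothetical label text (an assumption, not copied from real output).
SAMPLE_LABEL = ('<petal width (cm) &le; 0.8<br/>gini = 0.667<br/>'
                'samples = 150<br/>value = [50, 50, 50]>')


def parse_class_values(label):
    """Return the numbers inside 'value = [...]' as floats, or None if absent."""
    match = re.search(r'value = \[([^\]]*)\]', label)
    if match is None:
        return None
    # Long value lists may be wrapped with <br/>, so treat it as whitespace.
    inner = match.group(1).replace('<br/>', ' ')
    # Split on commas and/or whitespace and skip empty tokens.
    return [float(tok) for tok in re.split(r'[,\s]+', inner.strip()) if tok]


print(parse_class_values(SAMPLE_LABEL))  # [50.0, 50.0, 50.0]
```

Returning None for labels without a value entry makes it easy to skip the default 'node' and 'edge' entries that pydotplus also lists.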

We only have 3 classes, so each one gets its own color channel (red, green, blue). Mixed populations get mixed colors according to their class distribution.

values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])

We can now see the separation nicely: the greener a node is, the more samples of the 2nd class it contains, and the same holds for blue and the 3rd class.
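The mixing rule can be checked by hand without building a tree. In this small sketch, counts_to_hex is a helper name introduced here for illustration, and the sample counts are made up. It applies the same arithmetic as the two lines above.

```python
def counts_to_hex(values):
    """Map three per-class counts to an RGB hex string, one channel per class."""
    if len(values) != 3:
        raise ValueError('this mixing scheme only works for exactly 3 classes')
    total = sum(values)
    # Scale each class share to the 0..255 range of one color channel.
    channels = [int(255 * v / total) for v in values]
    return '#{:02x}{:02x}{:02x}'.format(*channels)


print(counts_to_hex([50, 0, 0]))    # '#ff0000', pure 1st class is red
print(counts_to_hex([0, 49, 5]))    # '#00e717', mostly 2nd class is green
print(counts_to_hex([50, 50, 50]))  # '#555555', an even mix is dark gray
```

Note that an evenly mixed node, such as the root of the iris tree, comes out dark gray with this scheme, so its label text may be hard to read.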


import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()

for node in nodes:
    if node.get_label():
        # float() also accepts labels whose values are printed as decimals
        values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        values = [int(255 * v / sum(values)) for v in values]
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
        node.set_fillcolor(color)

graph.write_png('colored_tree.png')

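A more general solution for more than 3 classes, which colors only the final nodes where a single class is present:

import numpy

# 'lightred' is not a Graphviz color name, so 'lightcoral' is used instead
colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')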

for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        values = clf.tree_.value[int(node.get_name())][0]
        #color only nodes where only one class is present
        if max(values) == sum(values):    
            node.set_fillcolor(colors[numpy.argmax(values)])
        #mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])
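Since the question asked for coloring by dominant class, one possible variation (not part of the answer above) is to give each class its own hue and fade it toward white as the node gets more mixed. This is only a sketch: class_palette and node_color are names introduced here for illustration, and the evenly spaced hues are an arbitrary choice.

```python
import colorsys


def class_palette(n_classes):
    """Return n_classes evenly spaced hues as (r, g, b) floats in [0, 1]."""
    return [colorsys.hsv_to_rgb(i / n_classes, 1.0, 1.0)
            for i in range(n_classes)]


def node_color(values, palette):
    """Hex color of the dominant class, blended toward white by impurity."""
    total = sum(values)
    dominant = max(range(len(values)), key=lambda i: values[i])
    purity = values[dominant] / total
    # purity == 1 keeps the full hue; lower purity moves channels toward 255.
    channels = [int(round(255 * (1.0 - purity * (1.0 - c))))
                for c in palette[dominant]]
    return '#{:02x}{:02x}{:02x}'.format(*channels)


palette = class_palette(4)
print(node_color([10, 0, 0, 0], palette))  # '#ff0000'
print(node_color([3, 1, 0, 0], palette))   # '#ff4040'
```

The resulting string can be passed to node.set_fillcolor() in the loop above in place of colors[numpy.argmax(values)].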




Answer 2:


Great answers, guys. Just to add to @Maximilian Peters's answer: another way to identify leaf nodes for specific coloring is to check the split criterion (threshold) values. Leaf nodes have no child nodes, so they have no split criterion either, and their threshold is stored as TREE_UNDEFINED.

https://github.com/scikit-learn/scikit-learn/blob/a24c8b464d094d2c468a16ea9f8bf8d42d949f84/sklearn/tree/_tree.pyx
TREE_UNDEFINED = -2
thresholds = clf.tree_.threshold
for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        node_id = int(node.get_name())
        values = clf.tree_.value[node_id][0]
        # color nodes where only one class is present, or any leaf node
        if (max(values) == sum(values)
                or thresholds[node_id] == TREE_UNDEFINED):
            node.set_fillcolor(colors[numpy.argmax(values)])
        # mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])
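Leaves can also be recognized from the child arrays, because scikit-learn stores TREE_LEAF (-1) in both children_left and children_right for a leaf. The sketch below uses plain, made-up lists standing in for clf.tree_.children_left and clf.tree_.children_right so that it runs without a fitted tree. leaf_node_ids is a helper name introduced here for illustration.

```python
TREE_LEAF = -1  # marker stored in the child arrays for leaf nodes


def leaf_node_ids(children_left, children_right):
    """Return the ids of all nodes that have no left and no right child."""
    return [node_id
            for node_id, (left, right) in enumerate(zip(children_left,
                                                        children_right))
            if left == TREE_LEAF and right == TREE_LEAF]


# Hypothetical 5-node tree: node 0 splits into 1 and 2, node 2 into 3 and 4.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]

print(leaf_node_ids(children_left, children_right))  # [1, 3, 4]
```

With a fitted classifier, the same function should accept clf.tree_.children_left and clf.tree_.children_right directly, since both are integer arrays indexed by node id.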

Not completely related to the question, but here is some more info in case it is helpful to others. Continuing on this idea of understanding the decision stumps of a tree-based classifier, Skater has added support for summarizing all forms of tree-based models using tree surrogates. Check out the examples here:

https://github.com/datascienceinc/Skater/blob/master/examples/rule_list_notebooks/explanation_using_tree_surrogate.ipynb



Source: https://stackoverflow.com/questions/43214350/color-of-the-node-of-tree-with-graphviz-using-class-names