Changing colors for decision tree plot created using export_graphviz

难免孤独 2020-12-16 21:14

I am using scikit-learn's regression tree function and graphviz to generate the wonderful, easy-to-interpret visuals of some decision trees:

dot_data = tree.expor         


        
1 answer
  •  醉梦人生
    2020-12-16 21:39

    • You can get a list of all the edges via graph.get_edge_list()
    • Each source node should have two target nodes: the one with the lower index is the branch taken when the split condition evaluates to True, the one with the higher index is the branch taken when it evaluates to False
    • Colors can be assigned via set_fillcolor()

    import pydotplus
    from sklearn.datasets import load_iris
    from sklearn import tree
    import collections
    
    clf = tree.DecisionTreeClassifier(random_state=42)
    iris = load_iris()
    
    clf = clf.fit(iris.data, iris.target)
    
    dot_data = tree.export_graphviz(clf,
                                    feature_names=iris.feature_names,
                                    out_file=None,
                                    filled=True,
                                    rounded=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    
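    # colors[0] is used for the True branch, colors[1] for the False branch
    colors = ('brown', 'forestgreen')
    edges = collections.defaultdict(list)

    # Group the child node indices by their parent node
    for edge in graph.get_edge_list():
        edges[edge.get_source()].append(int(edge.get_destination()))

    for edge in edges:
        # After sorting, the lower index (True branch) comes first
        edges[edge].sort()
        for i in range(2):
            dest = graph.get_node(str(edges[edge][i]))[0]
            dest.set_fillcolor(colors[i])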
    
    graph.write_png('tree.png')
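
The grouping step above can also be tried without Graphviz installed. The following is a minimal sketch, assuming the edges have already been reduced to (source, destination) pairs of string ids. The helper name `color_by_branch` and the sample edge list are hypothetical and are not part of pydotplus or scikit-learn.

```python
import collections


def color_by_branch(edge_pairs, colors=("brown", "forestgreen")):
    """Map each child node id to a color based on the branch that reaches it."""
    children = collections.defaultdict(list)
    for source, destination in edge_pairs:
        children[source].append(int(destination))

    node_colors = {}
    for source, dests in children.items():
        # Lower index is treated as the True branch, higher index as the False branch.
        for color, dest in zip(colors, sorted(dests)):
            node_colors[str(dest)] = color
    return node_colors


# Hypothetical edge list for a small tree: 0 -> (1, 2) and 2 -> (3, 4)
example_edges = [("0", "1"), ("0", "2"), ("2", "3"), ("2", "4")]
node_colors = color_by_branch(example_edges)
print(node_colors)
```

The root node never appears as a destination, so it keeps whatever fill color export_graphviz gave it. The same is true in the pydotplus version above.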
    

    "Also, I've seen some trees where the length of the lines connecting nodes is proportional to the % variance explained by the split. I'd love to be able to do that too, if possible!"

    You could play with set_weight() and set_len(), but that is a bit trickier and needs some fiddling to get right. Here is some code to get you started:

    for edge in edges:
        edges[edge].sort()
        src = graph.get_node(edge)[0]
        # Labels use the literal two-character sequence \n as the DOT line break
        total_weight = int(src.get_attributes()['label'].split('samples = ')[1].split('\\n')[0])
        for i in range(2):
            dest = graph.get_node(str(edges[edge][i]))[0]
            weight = int(dest.get_attributes()['label'].split('samples = ')[1].split('\\n')[0])
            # Edge from the parent to this child
            child_edge = graph.get_edge(edge, str(edges[edge][i]))[0]
            child_edge.set_weight((1 - weight / total_weight) * 100)
            child_edge.set_len(weight / total_weight)
            child_edge.set_minlen(weight / total_weight)
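
The chained split() calls above raise an error if a label lacks a "samples = " field. As a rough sketch, the parsing could be pulled into a small helper that uses a regular expression instead. The names `samples_in_label` and `child_fraction` are hypothetical, and the two labels are made-up examples in the general shape export_graphviz emits, not output copied from a real run.

```python
import re


def samples_in_label(label):
    """Return the integer following 'samples = ' in a node label, or None."""
    match = re.search(r"samples = (\d+)", label)
    return int(match.group(1)) if match else None


def child_fraction(parent_label, child_label):
    """Fraction of the parent's samples that end up in the child node."""
    parent = samples_in_label(parent_label)
    child = samples_in_label(child_label)
    if not parent or child is None:
        raise ValueError("label has no usable 'samples = N' field")
    return child / parent


# Hypothetical labels; in DOT source the line break is a literal backslash + n.
parent_label = '"petal width (cm) <= 0.8\\ngini = 0.667\\nsamples = 150\\nvalue = [50, 50, 50]"'
child_label = '"gini = 0.0\\nsamples = 50\\nvalue = [50, 0, 0]"'

fraction = child_fraction(parent_label, child_label)
print(fraction)
```

The resulting fraction can be passed to set_weight(), set_len() and set_minlen() as in the loop above. The Graphviz documentation describes len as honored only by the neato and fdp layout engines, and says dot expects integer weight values, so the default dot layout may show less effect than expected.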
