Question
I want to plot a decision tree of a random forest, so I wrote the following code:
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
import pydotplus
import six

clf = RandomForestClassifier(n_estimators=100)
dotfile = six.StringIO()
i_tree = 0
for tree_in_forest in clf.estimators_:
    if i_tree < 1:
        tree.export_graphviz(tree_in_forest, out_file=dotfile)
        pydotplus.graph_from_dot_data(dotfile.getvalue()).write_png('dtree' + str(i_tree) + '.png')
    i_tree = i_tree + 1
But it doesn't generate anything. Do you have any idea how to plot a decision tree from a random forest?
Thank you,
Answer 1:
Assuming your Random Forest model is already fitted, you should first import the export_graphviz function:
from sklearn.tree import export_graphviz
In your for loop, you could do the following to generate the dot file:
export_graphviz(tree_in_forest,
                out_file='tree.dot',  # write the dot source to this file
                feature_names=X.columns,
                filled=True,
                rounded=True)
The next lines then convert the dot file into a PNG file (this requires Graphviz's dot command-line tool to be installed):
import os
os.system('dot -Tpng tree.dot -o tree.png')
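A fuller sketch of the same idea follows. The iris dataset, n_estimators=5 and the tree_N.dot file names are placeholders chosen for illustration, not part of the answer. It loops over every tree with enumerate instead of a manual counter, writes one dot file per tree, and calls dot through subprocess.run, which avoids the shell and raises an error if the conversion fails. The PNG step is skipped when Graphviz is not installed.

```python
import os
import shutil
import subprocess
import tempfile

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# Toy data standing in for your own X, y (assumption for illustration).
iris = load_iris()
X, y = iris.data, iris.target

# The forest must be fitted before clf.estimators_ exists.
clf = RandomForestClassifier(n_estimators=5, random_state=0)
clf.fit(X, y)

out_dir = tempfile.mkdtemp()
dot_paths = []
for i_tree, tree_in_forest in enumerate(clf.estimators_):
    # Hypothetical file naming scheme: one .dot file per tree.
    dot_path = os.path.join(out_dir, f"tree_{i_tree}.dot")
    # out_file accepts a file name (or an open file handle).
    export_graphviz(tree_in_forest,
                    out_file=dot_path,
                    feature_names=iris.feature_names,
                    class_names=iris.target_names,
                    filled=True,
                    rounded=True)
    dot_paths.append(dot_path)
    # Convert to PNG only if Graphviz's dot executable is on the PATH.
    if shutil.which("dot") is not None:
        png_path = dot_path[:-4] + ".png"
        subprocess.run(["dot", "-Tpng", dot_path, "-o", png_path], check=True)

print(len(dot_paths), "dot files written to", out_dir)
```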
Answer 2:
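You can view each tree like this:
from io import StringIO
import pydotplus
from IPython.display import Image
from sklearn import tree

dotfile = StringIO()
i_tree = 0
for tree_in_forest in FT_cls_gini.estimators_:
    if i_tree == 3:  # export only the fourth tree (index 3)
        tree.export_graphviz(tree_in_forest, out_file=dotfile)
        graph = pydotplus.graph_from_dot_data(dotfile.getvalue())
    i_tree = i_tree + 1
Image(graph.create_png())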
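Because estimators_ is an ordinary Python list, the counter can likely be dropped and the tree indexed directly. The sketch below assumes a toy iris model in place of FT_cls_gini. Passing out_file=None makes export_graphviz return the dot source as a string; this also sidesteps a pitfall of reusing one StringIO buffer across several trees, where getvalue() would contain every tree written so far, concatenated.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# Toy model standing in for FT_cls_gini (assumption for illustration).
X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# estimators_ is a list of fitted DecisionTreeClassifier objects,
# so index 3 selects the fourth tree without a counter variable.
fourth_tree = forest.estimators_[3]

# With out_file=None the dot source is returned as a fresh string
# instead of being appended to a shared buffer.
dot_source = export_graphviz(fourth_tree, out_file=None, filled=True)
print(dot_source.splitlines()[0])  # prints: digraph Tree {
```

In a notebook, the string can then be rendered as in the answer above, for example with Image(pydotplus.graph_from_dot_data(dot_source).create_png()).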
Answer 3:
You can draw a single tree:
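import graphviz
from IPython import display
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import export_graphviz

m = RandomForestRegressor(n_estimators=1, max_depth=3, bootstrap=False, n_jobs=-1)
m.fit(X_train, y_train)
# export_graphviz expects a single tree, so pass the forest's only estimator
str_tree = export_graphviz(m.estimators_[0],
                           out_file=None,  # return the dot source as a string
                           feature_names=X_train.columns,  # column names
                           filled=True,
                           special_characters=True,
                           rotate=True,
                           precision=2)  # decimal places; must be an integer
# render the dot source as a graph instead of printing the raw string
display.display(graphviz.Source(str_tree))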
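If installing Graphviz is a problem, scikit-learn 0.21 and later also provide sklearn.tree.plot_tree, which draws a fitted tree with matplotlib. A minimal sketch follows; the iris data, the three shallow trees and the output file name are assumptions chosen only for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree

iris = load_iris()
# Shallow trees (max_depth=3) keep each plot readable.
forest = RandomForestClassifier(n_estimators=3, max_depth=3, random_state=0)
forest.fit(iris.data, iris.target)

# One subplot per tree; plot_tree draws directly onto the given Axes.
fig, axes = plt.subplots(1, len(forest.estimators_), figsize=(15, 5))
for i_tree, (tree_in_forest, ax) in enumerate(zip(forest.estimators_, axes)):
    plot_tree(tree_in_forest,
              feature_names=iris.feature_names,
              class_names=list(iris.target_names),
              filled=True,
              ax=ax)
    ax.set_title(f"Tree {i_tree}")

# Hypothetical output location for the combined figure.
out_path = os.path.join(tempfile.gettempdir(), "forest_trees.png")
fig.savefig(out_path)
plt.close(fig)
```

Fully grown trees, the random forest default, can become too large to read this way, which is why the sketch limits max_depth.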
Source: https://stackoverflow.com/questions/40155128/plot-trees-for-a-random-forest-in-python-with-scikit-learn