Question
I have a dependent variable to classify with a decision tree. It has three categories with frequencies 738 (19%), 426 (15%) and 1800 (66%). As you might imagine, the predicted category is always the third one, but the purpose of the tree is descriptive, so that does not actually matter.
The thing is, when plotting a tree fitted with the ctree() function (package partykit), the terminal nodes display histograms showing the probability of occurrence of the three classes. I need to modify this output: I would like to obtain the proportion of each class that falls within the terminal node, relative to that class's absolute frequency.
For example, what percentage of the 738 participants in class 1 belongs to a given terminal node? Each terminal node would display these values for all three classes of the dependent variable.
Below is a plot of the tree, which by default reports the prevalence of each class within the terminal nodes.
Answer 1:
You can always define your own panel function to draw what goes into each terminal panel window. If you know a little bit about grid graphics and you look at how the current terminal panel functions are defined, you will see how this works.
One panel function that ought to do what you want is node_terminal() in the partykit package (the much improved re-implementation of the old party package). However, because ctree() does not store its predictions in each terminal node, the node_terminal() function cannot do this out of the box at the moment. I'll try to improve the implementation in future versions so that this can be facilitated. Below is a somewhat involved example that should do what you want, I hope.
First, we fit a classification tree using the iris data (for a simple reproducible example):
library("partykit")
(ct <- ctree(Species ~ ., data = iris))
## Model formula:
## Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
##
## Fitted party:
## [1] root
## |   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
## |   [3] Petal.Length > 1.9
## |   |   [4] Petal.Width <= 1.7
## |   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
## |   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
## |   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
##
## Number of inner nodes: 3
## Number of terminal nodes: 4
Then we compute a table of predicted probabilities for each terminal node:
(pred <- aggregate(predict(ct, type = "prob"),
  list(predict(ct, type = "node")), FUN = mean))
##   Group.1 setosa versicolor  virginica
## 1       2      1 0.00000000 0.00000000
## 2       5      0 0.97826087 0.02173913
## 3       6      0 0.50000000 0.50000000
## 4       7      0 0.02173913 0.97826087
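Note that the table above holds within-node proportions (each row sums to 1), whereas the question asks for the share of each class's total that lands in a node (each column sums to 1). A minimal sketch of that second quantity, using only base R's table() and prop.table() on top of the fitted node ids; the names node_by_class and share_of_class are illustrative and not part of partykit:

```r
library("partykit")

ct <- ctree(Species ~ ., data = iris)

# Cross-tabulate terminal node id (rows) against observed class (columns)
node_by_class <- table(node = predict(ct, type = "node"),
                       class = iris$Species)

# Within-node proportions: rows sum to 1 (what the default plot shows)
round(prop.table(node_by_class, margin = 1), 3)

# Share of each class's total found in each node: columns sum to 1
share_of_class <- prop.table(node_by_class, margin = 2)
round(share_of_class, 3)
```

Reading down a column of share_of_class answers questions such as "what percentage of all versicolor observations end up in node 5?".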
Then comes the not so obvious part: We want to include these predicted probabilities in the terminal nodes of the tree itself. For this, we coerce the recursive node structure to a flat list, insert the predictions (suitably formatted), and convert the list back to the node structure:
ct_node <- as.list(ct$node)
for(i in 1:nrow(pred)) {
  ct_node[[pred[i,1]]]$info$prediction <- paste(
    format(names(pred)[-1]),
    format(round(pred[i, -1], digits = 3), nsmall = 3)
  )
}
ct$node <- as.partynode(ct_node)
Then we can easily draw a picture of the tree with the node_terminal panel function and insert our pre-formatted predictions:
plot(ct, terminal_panel = node_terminal, tp_args = list(
  FUN = function(node) c("Predictions", node$prediction)))
EDIT: The coercing back and forth between a list and a party is actually already implemented in the package... I just forgot about it ;-) If you do
st <- as.simpleparty(ct)
then the resulting party has more detailed information about the predictions etc. in each node. For example, $distribution then contains the absolute frequencies for each response level. This can easily be formatted as before:
pred <- function(i) {
  tab <- i$distribution
  tab <- round(prop.table(tab), 3)
  tab <- paste0(names(tab), ":", format(tab, nsmall = 3))
  c("Predictions", tab)
}
This can be passed to node_terminal to create essentially the same plot as above. You might want to set drop_terminal = TRUE in the plot() call (instead of the default FALSE) if you want all terminal nodes to be displayed in the bottom row.
plot(st, terminal_panel = node_terminal, tp_args = list(FUN = pred))
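The plots so far show proportions within each node. To display what the question actually asks for, the share of each class's absolute frequency that falls into a node, the same formatting idea can be applied with a different denominator. The following is only a sketch: class_total and class_share are hypothetical helper names, and it relies on $distribution holding named absolute frequencies per response level, as described above.

```r
library("partykit")

ct <- ctree(Species ~ ., data = iris)
st <- as.simpleparty(ct)

# Absolute frequency of each class in the full learning sample
class_total <- table(iris$Species)

# Hypothetical formatter: percentage of each class's total found in the node
class_share <- function(info) {
  tab <- info$distribution                      # absolute frequencies in this node
  share <- as.numeric(tab) / as.numeric(class_total[names(tab)])
  c("Share of class",
    paste0(names(tab), ": ", format(round(100 * share, 1), nsmall = 1), "%"))
}

plot(st, terminal_panel = node_terminal, tp_args = list(FUN = class_share))
```

With this formatter the percentages for one class add up to 100 across all terminal nodes, rather than the percentages within one node adding up to 100.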
Source: https://stackoverflow.com/questions/30644908/modifying-terminal-node-in-ctree-partykit-package