Question
Hi, I'm trying to extract some of the inner node information stored in the constparty object that ctree() from partykit returns in R, but I'm finding the object a bit difficult to navigate. I'm able to display the information on a plot, but I'm not sure how to extract it. I think it requires nodeapply() or another function in partykit?
library(partykit)
irisct <- ctree(Species ~ .,data = iris)
plot(irisct, inner_panel = node_barplot(irisct))
[Figure: plot with inner node details]
All the information is accessible to the plotting functions, but I'm after a text output similar to this:
[Figure: example output]
Answer 1:
The main trick (as previously pointed out by @G5W) is to take the [id] subset of the party object and then extract the data (either via $data or using the data_party() function), which contains the response. I would recommend building a table with absolute frequencies first and then computing the relative and marginal frequencies from that. Using the irisct object, the plain table can be obtained by
tab <- sapply(1:length(irisct), function(id) {
y <- data_party(irisct[id])
y <- y[["(response)"]]
table(y)
})
tab
##            [,1] [,2] [,3] [,4] [,5] [,6] [,7]
## setosa       50   50    0    0    0    0    0
## versicolor   50    0   50   49   45    4    1
## virginica    50    0   50    5    1    4   45
Then we can add a little bit of formatting to get a nice table object:
colnames(tab) <- 1:length(irisct)
tab <- as.table(tab)
names(dimnames(tab)) <- c("Species", "Node")
And then use prop.table() and margin.table() to compute the frequencies we are interested in. The as.data.frame() method transforms the table layout into a "long" data.frame:
# margin = 2 normalizes within each node (column), giving class proportions per node
as.data.frame(prop.table(tab, 2))
##       Species Node       Freq
## 1      setosa    1 0.33333333
## 2  versicolor    1 0.33333333
## 3   virginica    1 0.33333333
## 4      setosa    2 1.00000000
## 5  versicolor    2 0.00000000
## 6   virginica    2 0.00000000
## 7      setosa    3 0.00000000
## 8  versicolor    3 0.50000000
## 9   virginica    3 0.50000000
## 10     setosa    4 0.00000000
## 11 versicolor    4 0.90740741
## 12  virginica    4 0.09259259
## 13     setosa    5 0.00000000
## 14 versicolor    5 0.97826087
## 15  virginica    5 0.02173913
## 16     setosa    6 0.00000000
## 17 versicolor    6 0.50000000
## 18  virginica    6 0.50000000
## 19     setosa    7 0.00000000
## 20 versicolor    7 0.02173913
## 21  virginica    7 0.97826087
as.data.frame(margin.table(tab, 2))
##   Node Freq
## 1    1  150
## 2    2   50
## 3    3  100
## 4    4   54
## 5    5   46
## 6    6    8
## 7    7   46
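If a single wide table is easier to work with than the long layout, the counts and within-node proportions can be combined into one data frame. This is only a sketch under the same assumptions as above (the response is stored in the "(response)" column returned by data_party()); the name node_summary is made up for illustration.

```r
library(partykit)

irisct <- ctree(Species ~ ., data = iris)
ids <- nodeids(irisct)  # all node IDs, inner and terminal

# One row of absolute class counts per node
counts <- t(sapply(ids, function(id) {
  table(data_party(irisct[id])[["(response)"]])
}))

# node_summary (hypothetical name): node ID, node size, class shares per node
node_summary <- data.frame(
  node = ids,
  n = rowSums(counts),
  round(counts / rowSums(counts), 3)
)
node_summary
```

Each row then describes one node: its ID, the number of observations that reached it, and the share of each species among those observations.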
And the split information can be obtained with the (still unexported) .list.rules.party() function. You just need to ask for all node IDs (the default is to use just the terminal node IDs):
partykit:::.list.rules.party(irisct, i = nodeids(irisct))
## 1
## ""
## 2
## "Petal.Length <= 1.9"
## 3
## "Petal.Length > 1.9"
## 4
## "Petal.Length > 1.9 & Petal.Width <= 1.7"
## 5
## "Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8"
## 6
## "Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length > 4.8"
## 7
## "Petal.Length > 1.9 & Petal.Width > 1.7"
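Since the question mentions nodeapply(), here is a rough sketch of getting the split of each inner node using only exported functions. It assumes every split is a binary split on a numeric variable, which happens to hold for this iris tree but not in general (splits on factors have no breakpoint). The names inner and split_table are made up for illustration.

```r
library(partykit)

irisct <- ctree(Species ~ ., data = iris)

# Inner nodes are all node IDs that are not terminal node IDs
inner <- setdiff(nodeids(irisct), nodeids(irisct, terminal = TRUE))

# For each inner node, look up the split variable and its breakpoint.
# Assumption: numeric splits only, so breaks_split() returns one number.
splits <- nodeapply(irisct, ids = inner, FUN = function(n) {
  sp <- split_node(n)
  data.frame(
    node = id_node(n),
    variable = names(irisct$data)[varid_split(sp)],
    breakpoint = breaks_split(sp)
  )
})

split_table <- do.call(rbind, splits)
split_table
```

Unlike the rule strings above, this gives only the split made at each node itself, not the full path from the root.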
Answer 2:
Most of the information that you want is accessible without much work. I will show how to get the information, but leave you to format the information into a pretty table.
Notice that your tree structure irisct can be indexed like a list of its nodes: length() gives the number of nodes, and irisct[id] gives the subtree rooted at node id.
length(irisct)
[1] 7
Each node has a field data that contains the points that have made it down this far in the tree, so you can get the number of observations at the node by counting the rows.
dim(irisct[4]$data)
[1] 54 5
nrow(irisct[4]$data)
[1] 54
Or do them all at once to get your table 2:
NObs = sapply(1:7, function(n) { nrow(irisct[n]$data) })
NObs
[1] 150 50 100 54 46 8 46
The first column of the data at a node is the class (Species), so you can get the count of each class and the probability of each class at a node:
table(irisct[4]$data[1])
    setosa versicolor  virginica
         0         49          5
table(irisct[4]$data[1]) / NObs[4]
    setosa versicolor  virginica
0.00000000 0.90740741 0.09259259
The split information in your table 3 is a bit more awkward. Still, you can get a text version of what you need just by printing out the top-level node:
irisct[1]
Model formula:
Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
Fitted party:
[1] root
| [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
| [3] Petal.Length > 1.9
| | [4] Petal.Width <= 1.7
| | | [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
| | | [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
| | [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
Number of inner nodes: 3
Number of terminal nodes: 4
To save the output for parsing and display:
TreeSplits = capture.output(print(irisct[1]))
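As a rough illustration of the parsing step, the node lines can be picked out of the captured text with regular expressions. This sketch assumes the print layout shown above (a bracketed node ID followed by the split text), which is a display format rather than a stable interface; the names node_lines, node_id and node_text are made up for illustration.

```r
library(partykit)

irisct <- ctree(Species ~ ., data = iris)
TreeSplits <- capture.output(print(irisct[1]))

# Keep only the lines that contain a bracketed node ID such as "[4]"
node_lines <- grep("\\[[0-9]+\\]", TreeSplits, value = TRUE)

# Node ID: the number inside the brackets
node_id <- as.integer(sub(".*\\[([0-9]+)\\].*", "\\1", node_lines))

# Node text: everything after the indentation bars and the bracketed ID
node_text <- sub("^[| ]*\\[[0-9]+\\] *", "", node_lines)

data.frame(node = node_id, text = node_text)
```

Because this depends on how print() happens to lay out the tree, extracting the splits from the node objects directly is likely to be more robust.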
Source: https://stackoverflow.com/questions/41968910/r-extracting-inner-node-information-and-splits-from-ctree-partykit