R Extracting inner node information and splits from ctree (partykit)

Submitted by 与世无争的帅哥 on 2019-12-06 06:52:00

The main trick (as previously pointed out by @G5W) is to take the [id] subset of the party object and then extract its data, either via $data or with the data_party() function, which contains the response. I would recommend building a table of absolute frequencies first and then computing the relative and marginal frequencies from that. Using the irisct object, the plain table can be obtained by

tab <- sapply(1:length(irisct), function(id) {
  y <- data_party(irisct[id])
  y <- y[["(response)"]]
  table(y)
})
tab
##            [,1] [,2] [,3] [,4] [,5] [,6] [,7]
## setosa       50   50    0    0    0    0    0
## versicolor   50    0   50   49   45    4    1
## virginica    50    0   50    5    1    4   45
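The irisct object itself is never constructed in this answer. The following is a minimal, self-contained sketch that assumes it was fitted with ctree() on the built-in iris data using default settings (an assumption, although the printed tree further down is consistent with it). It also takes the node IDs from nodeids() instead of 1:length(irisct), which is a stylistic choice rather than a requirement.

```r
library(partykit)

# Assumption: the question's tree was fitted like this (not shown in the answer).
irisct <- ctree(Species ~ ., data = iris)

# All node IDs (inner and terminal), in the same order as 1:length(irisct).
ids <- nodeids(irisct)

# For each node, tabulate the response among the observations reaching it.
tab <- sapply(ids, function(id) {
  table(data_party(irisct[id])[["(response)"]])
})
colnames(tab) <- ids
tab
```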

Then we can add a little bit of formatting to turn this into a nice table object:

colnames(tab) <- 1:length(irisct)
tab <- as.table(tab)
names(dimnames(tab)) <- c("Species", "Node")

And then use prop.table() and margin.table() to compute the frequencies we are interested in. The as.data.frame() method transforms the table layout into a "long" data.frame:

as.data.frame(prop.table(tab, 1))
##       Species Node        Freq
## 1      setosa    1 0.500000000
## 2  versicolor    1 0.251256281
## 3   virginica    1 0.322580645
## 4      setosa    2 0.500000000
## 5  versicolor    2 0.000000000
## 6   virginica    2 0.000000000
## 7      setosa    3 0.000000000
## 8  versicolor    3 0.251256281
## 9   virginica    3 0.322580645
## 10     setosa    4 0.000000000
## 11 versicolor    4 0.246231156
## 12  virginica    4 0.032258065
## 13     setosa    5 0.000000000
## 14 versicolor    5 0.226130653
## 15  virginica    5 0.006451613
## 16     setosa    6 0.000000000
## 17 versicolor    6 0.020100503
## 18  virginica    6 0.025806452
## 19     setosa    7 0.000000000
## 20 versicolor    7 0.005025126
## 21  virginica    7 0.290322581

as.data.frame(margin.table(tab, 2))
##   Node Freq
## 1    1  150
## 2    2   50
## 3    3  100
## 4    4   54
## 5    5   46
## 6    6    8
## 7    7   46
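Note that prop.table(tab, 1) normalizes within each species across all nodes, so nested nodes are counted more than once in each row total. If what you are after is the class distribution within each node, normalizing over margin 2 may be closer to what is needed. A minimal sketch, again assuming the tree was fitted with ctree(Species ~ ., data = iris):

```r
library(partykit)

# Assumption: default ctree() fit on iris, as in the question.
irisct <- ctree(Species ~ ., data = iris)

# Rebuild the table of absolute frequencies (species x node).
ids <- nodeids(irisct)
tab <- sapply(ids, function(id) {
  table(data_party(irisct[id])[["(response)"]])
})
colnames(tab) <- ids
tab <- as.table(tab)
names(dimnames(tab)) <- c("Species", "Node")

# Margin 2 = columns: each node's column now sums to 1.
node_prob <- prop.table(tab, 2)
round(node_prob, 3)
```

With this normalization, each column gives the estimated class probabilities for the observations that reach that node.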

And the split information can be obtained with the (still unexported) .list.rules.party() function. You just need to ask for all node IDs (the default is to use just the terminal node IDs):

partykit:::.list.rules.party(irisct, i = nodeids(irisct))
##                                                               1 
##                                                              "" 
##                                                               2 
##                                           "Petal.Length <= 1.9" 
##                                                               3 
##                                            "Petal.Length > 1.9" 
##                                                               4 
##                       "Petal.Length > 1.9 & Petal.Width <= 1.7" 
##                                                               5 
## "Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8" 
##                                                               6 
##  "Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length > 4.8" 
##                                                               7 
##                        "Petal.Length > 1.9 & Petal.Width > 1.7" 
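Because .list.rules.party() is unexported, it may change between partykit versions. As a rough alternative, the individual splits can be read from the inner nodes with exported accessors. This sketch assumes the default ctree() fit on iris and only handles numeric splits with a single break point, which is all that this tree contains. The column names node, variable and breakpoint are made up for illustration.

```r
library(partykit)

# Assumption: default ctree() fit on iris, as in the question.
irisct <- ctree(Species ~ ., data = iris)

# Inner nodes are all node IDs that are not terminal.
inner <- setdiff(nodeids(irisct), nodeids(irisct, terminal = TRUE))

# For each inner node, pull the split variable and its cut point.
# Only valid for numeric splits; categorical splits have no break point.
splits <- do.call(rbind, nodeapply(irisct, ids = inner, FUN = function(node) {
  sp <- split_node(node)
  data.frame(
    node = id_node(node),
    variable = names(irisct$data)[varid_split(sp)],
    breakpoint = breaks_split(sp)
  )
}))
splits
```

Unlike the rule strings above, this gives one row per inner node with only that node's own split, not the full path from the root.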

Most of the information that you want is accessible without much work. I will show how to get the information, but leave you to format the information into a pretty table.

Notice that your tree structure irisct is just a list of each of the nodes.

length(irisct)
[1] 7

Each node has a field data that contains the points that have made it down this far in the tree, so you can get the number of observations at the node by counting the rows.

dim(irisct[4]$data)
[1] 54  5
nrow(irisct[4]$data)
[1] 54

Or do them all at once to get your table 2:

NObs = sapply(1:7, function(n) { nrow(irisct[n]$data) })
NObs
[1] 150  50 100  54  46   8  46
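As a quick cross-check of the terminal-node counts, the fitted node IDs can be tabulated directly. This is a sketch that assumes irisct was fitted with ctree(Species ~ ., data = iris). It covers terminal nodes only, so it complements rather than replaces the counts above.

```r
library(partykit)

# Assumption: default ctree() fit on iris, as in the question.
irisct <- ctree(Species ~ ., data = iris)

# predict(type = "node") returns the terminal node each observation falls into.
terminal_counts <- table(predict(irisct, type = "node"))
terminal_counts
```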

The first column of the data at a node is the class (Species), so you can get the count of each class and the probability of each class at a node

table(irisct[4]$data[1])
    setosa versicolor  virginica 
         0         49          5 
table(irisct[4]$data[1]) / NObs[4]
    setosa versicolor  virginica 
0.00000000 0.90740741 0.09259259 

The split information in your table 3 is a bit more awkward. Still, you can get a text version of what you need just by printing out the top level node

irisct[1]
Model formula:
Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
Fitted party:
[1] root
|   [2] Petal.Length <= 1.9: setosa (n = 50, err = 0.0%)
|   [3] Petal.Length > 1.9
|   |   [4] Petal.Width <= 1.7
|   |   |   [5] Petal.Length <= 4.8: versicolor (n = 46, err = 2.2%)
|   |   |   [6] Petal.Length > 4.8: versicolor (n = 8, err = 50.0%)
|   |   [7] Petal.Width > 1.7: virginica (n = 46, err = 2.2%)
Number of inner nodes:    3
Number of terminal nodes: 4

To save the output for parsing and display, capture it as a character vector:

TreeSplits = capture.output(print(irisct[1]))
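One possible way to parse the captured lines is with a few regular expressions. This is only a sketch: it assumes the default ctree() fit on iris and relies on the print layout shown above ("|" for indentation, "[id]" before each rule), which is not a documented interface and could change between partykit versions. The names parsed, node, depth and split are made up for illustration.

```r
library(partykit)

# Assumption: default ctree() fit on iris, as in the question.
irisct <- ctree(Species ~ ., data = iris)
TreeSplits <- capture.output(print(irisct[1]))

# Keep only lines that describe a node, e.g. "|   [2] Petal.Length <= 1.9: ...".
node_lines <- grep("^[| ]*\\[[0-9]+\\] ", TreeSplits, value = TRUE)

# Node ID is the number inside the square brackets.
node <- as.integer(sub("^[| ]*\\[([0-9]+)\\].*$", "\\1", node_lines))

# Depth is the number of "|" indentation markers in front of the node.
depth <- lengths(regmatches(node_lines, gregexpr("|", node_lines, fixed = TRUE)))

# Split text: drop the prefix, then drop the ": class (n = ..., err = ...)" part.
split <- sub("^[| ]*\\[[0-9]+\\] ", "", node_lines)
split <- sub(": .*$", "", split)

parsed <- data.frame(node = node, depth = depth, split = split)
parsed
```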