How to get terminal nodes for a new observation from an rpart object?

廉价感情. Submitted on 2019-12-19 04:15:10

Question


Say I have

library(rpart)  # provides rpart() and the kyphosis data set
head(kyphosis)
inTrain <- sample(1:nrow(kyphosis), 45, replace = FALSE)
TRAIN_KYPHOSIS <- kyphosis[inTrain,]
TEST_KYPHOSIS <- kyphosis[-inTrain,]

(kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS))

How do I get the terminal node of the fitted tree for each observation in TEST_KYPHOSIS?

How do I get a summary, such as the deviance and the predicted value, of the terminal node that each test observation maps to?


Answer 1:


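rpart actually has this functionality, but it is not exposed (strangely enough, given that it is a rather obvious requirement). The wrapper below reaches the internal functions pred.rpart and rpart.matrix via ::: and returns the node number for each observation: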

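predict_nodes <-
    function (object, newdata, na.action = na.pass) {
        # Row index into object$frame for each observation
        where <-
            if (missing(newdata)) 
                object$where  # training data: already stored in the fit
            else {
                if (is.null(attr(newdata, "terms"))) {
                    # Build a model frame that contains the predictors only
                    Terms <- delete.response(object$terms)
                    newdata <- model.frame(Terms, newdata, na.action = na.action, 
                                           xlev = attr(object, "xlevels"))
                    if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                        .checkMFClasses(cl, newdata, TRUE)
                }
                # Drop the new observations down the tree (internal rpart functions)
                rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))
            }
        # Translate the row indices into rpart node numbers
        as.integer(row.names(object$frame))[where]
    }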

And then:

> predict_nodes(kyph_tree, TEST_KYPHOSIS)
 [1] 5 3 4 3 3 5 5 3 3 3 3 5 5 4 3 5 4 3 3 3 3 4 3 4 4 5 5 3 4 4 3 5 3 5 5 5
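
For the second half of the question, the node numbers can be matched against the row names of kyph_tree$frame, whose dev and yval columns hold each node's deviance and predicted value. The following is a minimal sketch rather than the code from this answer: it assumes a regression (anova) tree and, instead of the internal functions above, uses a common workaround of overwriting yval in a copy of the tree with the node numbers so that the exported predict() returns them. The seed and the names node_tree, test_nodes and node_summary are assumptions made for illustration.

```r
library(rpart)

set.seed(1)  # assumed seed, only to make the sketch repeatable
inTrain <- sample(1:nrow(kyphosis), 45, replace = FALSE)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
TEST_KYPHOSIS <- kyphosis[-inTrain, ]
kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)

# Copy the tree and store each node's number as its "prediction",
# so that predict() on the copy returns node numbers (anova trees only).
node_tree <- kyph_tree
node_tree$frame$yval <- as.numeric(row.names(node_tree$frame))
test_nodes <- as.integer(predict(node_tree, newdata = TEST_KYPHOSIS))

# Look up size, deviance and fitted value of each observation's node
fr <- kyph_tree$frame
idx <- match(test_nodes, as.integer(row.names(fr)))
node_summary <- data.frame(
    obs  = row.names(TEST_KYPHOSIS),
    node = test_nodes,
    n    = fr$n[idx],
    dev  = fr$dev[idx],
    yval = fr$yval[idx]
)
head(node_summary)
```

The yval column here should match what predict(kyph_tree, newdata = TEST_KYPHOSIS) returns, which is a quick way to check that the node lookup is consistent.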



Answer 2:


One option is to convert the rpart object to an object of class party from the partykit package, which provides a general toolkit for dealing with recursive partitions. The conversion is simple:

library("partykit")
(kyph_party <- as.party(kyph_tree))

Model formula:
Number ~ Kyphosis + Age + Start

Fitted party:
[1] root
|   [2] Start >= 15.5: 2.933 (n = 15, err = 10.9)
|   [3] Start < 15.5
|   |   [4] Age >= 112.5: 3.714 (n = 14, err = 18.9)
|   |   [5] Age < 112.5: 5.125 (n = 16, err = 29.8)

Number of inner nodes:    2
Number of terminal nodes: 3

(For exact reproducibility, run set.seed(1) before the code from your question, then run my code.)

For objects of this class there are somewhat more flexible methods for plot(), predict(), fitted(), etc. For example, plot(kyph_party) yields a more informative display than the default plot(kyph_tree). The fitted() method extracts a two-column data.frame with the fitted node numbers and the observed responses on the training data.

kyph_fit <- fitted(kyph_party)
head(kyph_fit, 3)

  (fitted) (response)
1        5          6
2        2          2
3        4          3

With this you can easily compute any quantity you are interested in, e.g., the mean, median, or residual sum of squares within each node.

tapply(kyph_fit[,2], kyph_fit[,1], mean)

       2        4        5 
2.933333 3.714286 5.125000 

tapply(kyph_fit[,2], kyph_fit[,1], median)

2 4 5 
3 4 5 

tapply(kyph_fit[,2], kyph_fit[,1], function(x) sum((x - mean(x))^2))

       2        4        5 
10.93333 18.85714 29.75000 

Instead of the simple tapply() you can use any other function of your choice to compute the tables of grouped statistics.
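
As one illustration of such an alternative, the sketch below collects several statistics per node in a single table using split() and sapply(). It assumes the same setup as above, rebuilt here with set.seed(1) so the block runs on its own. by_node and node_stats are names chosen for this example only.

```r
library(rpart)
library(partykit)

set.seed(1)  # assumed seed, as suggested above
inTrain <- sample(1:nrow(kyphosis), 45, replace = FALSE)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)
kyph_party <- as.party(kyph_tree)
kyph_fit <- fitted(kyph_party)

# Group the observed responses by the id of the terminal node they fall into
by_node <- split(kyph_fit[["(response)"]], kyph_fit[["(fitted)"]])

# One row per terminal node, one column per statistic
node_stats <- t(sapply(by_node, function(x) {
    c(n = length(x),
      mean = mean(x),
      median = median(x),
      rss = sum((x - mean(x))^2))
}))
node_stats
```

The row names of node_stats are the terminal node ids, so the table can be indexed by node id as a character string.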

Now, to learn which node in the tree each observation from the test data TEST_KYPHOSIS falls into, you can simply use the predict(..., type = "node") method:

kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
head(kyph_pred)

 2  3  4  6  7 10 
 4  4  5  2  2  5 
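
To tie this back to the question, the predicted node ids can be matched against per-node statistics computed from the training data. This is a minimal sketch, assuming the same setup with set.seed(1). The names node_mean, node_rss, key and test_summary are illustrative and not part of partykit.

```r
library(rpart)
library(partykit)

set.seed(1)  # assumed seed, as suggested above
inTrain <- sample(1:nrow(kyphosis), 45, replace = FALSE)
TRAIN_KYPHOSIS <- kyphosis[inTrain, ]
TEST_KYPHOSIS <- kyphosis[-inTrain, ]
kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS)
kyph_party <- as.party(kyph_tree)

# Per-node mean and residual sum of squares on the training data
kyph_fit <- fitted(kyph_party)
node_mean <- tapply(kyph_fit[, 2], kyph_fit[, 1], mean)
node_rss <- tapply(kyph_fit[, 2], kyph_fit[, 1],
                   function(x) sum((x - mean(x))^2))

# Terminal node id for each test observation
kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
key <- as.character(kyph_pred)

# Attach the training-based node statistics to each test observation
test_summary <- data.frame(
    obs       = row.names(TEST_KYPHOSIS),
    node      = as.integer(kyph_pred),
    predicted = as.vector(node_mean[key]),
    deviance  = as.vector(node_rss[key])
)
head(test_summary)
```

For an unweighted regression tree such as this one, the predicted column should agree with predict(kyph_tree, newdata = TEST_KYPHOSIS), because each terminal node predicts the mean response of its training observations.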


Source: https://stackoverflow.com/questions/29304349/how-to-get-terminal-nodes-for-a-new-observation-from-an-rpart-object
