cforest prints empty tree

刺人心 2021-01-03 04:10

I'm trying to use the cforest function (R, party package).

This is what I do to construct the forest:

library("party")
set.seed(42)
readingSkills.cf <-          


        
2 Answers
  •  孤城傲影
    2021-01-03 04:32

    Short answer: the case weights in each node are NULL, i.e. not stored. The prettytree function prints weights = 0 because sum(NULL) equals 0 in R.
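
    To see why the printed value is 0 rather than NA or an error, here is a minimal base R sketch. The `node` list is a hypothetical stand-in for a tree node, not a real party object.

    ```r
    # Hypothetical stand-in for a tree node whose weights were never stored.
    node <- list(nodeID = 2L, weights = NULL)

    is.null(node$weights)   # TRUE: there is nothing to sum
    length(node$weights)    # 0
    sum(node$weights)       # 0, the sum over an empty set
    ```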


    Consider the following ctree example:

    library("party")
    x <- ctree(Species ~ ., data=iris)
    plot(x, type="simple")
    

    ctree plot

    For the resulting object x (class BinaryTree) the case weights are stored in each node:

    R> sum(x@tree$left$weights)
    [1] 50
    R> sum(x@tree$right$weights)
    [1] 100
    R> sum(x@tree$right$left$weights)
    [1] 54
    R> sum(x@tree$right$right$weights)
    [1] 46
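
    The sketch below illustrates in base R what "case weights stored in each node" could look like. It assumes a node's weights form a length-n vector that is zero for observations outside the node. The two split rules are hand-picked assumptions that reproduce the four sums above; they are not read from `x`.

    ```r
    # Assumption: a node's weights are a length-n vector, 0 outside the node.
    # The split rules are chosen by hand to match the sums shown above.
    w_root  <- rep(1, nrow(iris))
    w_left  <- w_root * (iris$Petal.Length <= 1.9)
    w_right <- w_root - w_left
    w_right_left  <- w_right * (iris$Petal.Width <= 1.7)
    w_right_right <- w_right - w_right_left

    # Sum of the case weights per (hypothetical) node
    node_sums <- sapply(list(left = w_left, right = w_right,
                             right_left = w_right_left,
                             right_right = w_right_right), sum)
    node_sums
    ```

    Under this assumption the two children of a node always add up to their parent, which is why 54 and 46 sum to 100.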
    

    Now let's take a closer look at cforest:

    y <- cforest(Species ~ ., data=iris, control=cforest_control(mtry=2))
    tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
    plot(new("BinaryTree", tree=tr, data=y@data, responses=y@responses))
    

    cforest tree

    The case weights are not stored in the tree ensemble, which can be seen by the following:

    fixInNamespace("print.TerminalNode", "party")
    

    Change the print method to:

    function (x, n = 1, ...)
    {
        print(names(x))
        print(x$weights)
        cat(paste(paste(rep(" ", n - 1), collapse = ""), x$nodeID,
            ")* ", sep = "", collapse = ""), "weights =", sum(x$weights),
            "\n")
    }
    

    Now we can observe that weights is NULL in every node:

    R> tr
    1) Petal.Width <= 0.4; criterion = 10.641, statistic = 10.641
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
      2)*  weights = 0 
    1) Petal.Width > 0.4
      3) Petal.Width <= 1.6; criterion = 8.629, statistic = 8.629
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
        4)*  weights = 0 
      3) Petal.Width > 1.6
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
        5)*  weights = 0 
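
    The same check can be sketched without editing the print method, by walking the nested node list recursively. `null_weight_nodes` and the toy tree are hypothetical helpers for illustration. The element names (`nodeID`, `weights`, `terminal`, `left`, `right`) are assumed from the names() output above.

    ```r
    # Return the nodeIDs of all nodes whose `weights` element is NULL.
    # Assumes each node is a list with `nodeID`, `terminal`, `weights`
    # and, for inner nodes, `left` and `right`.
    null_weight_nodes <- function(node) {
      ids <- if (is.null(node$weights)) node$nodeID else integer(0)
      if (!isTRUE(node$terminal)) {
        ids <- c(ids,
                 null_weight_nodes(node$left),
                 null_weight_nodes(node$right))
      }
      ids
    }

    # Toy tree (hypothetical data) with the same shape as the output above.
    leaf <- function(id) list(nodeID = id, weights = NULL, terminal = TRUE)
    toy <- list(nodeID = 1L, weights = NULL, terminal = FALSE,
                left  = leaf(2L),
                right = list(nodeID = 3L, weights = NULL, terminal = FALSE,
                             left = leaf(4L), right = leaf(5L)))

    null_weight_nodes(toy)  # 1 2 3 4 5
    ```

    Applied to the real object this would be `null_weight_nodes(tr)`, provided the node lists in `tr` carry these element names.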
    

    Update: this is a hack to display the sums of the case weights:

    update_tree <- function(x) {
      if(!x$terminal) {
        x$left <- update_tree(x$left)
        x$right <- update_tree(x$right)
      } else {
        x$weights <- x[[9]]
        x$weights_ <- x[[9]]
      }
      x
    }
    tr_weights <- update_tree(tr)
    plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))
    

    cforest tree with case weights
