ctree() - How to get the list of splitting conditions for each terminal node?

旧城冷巷雨未停 提交于 2019-12-17 18:33:39

问题


I have an output from ctree() (party package) that looks like the following. How do I get the list of splitting conditions for each terminal node, like sns <= 0, dta <= 1; sns <= 0, dta > 1 and so on?

1) sns <= 0; criterion = 1, statistic = 14655.021
  2) dta <= 1; criterion = 1, statistic = 3286.389
   3)*  weights = 153682 
  2) dta > 1
   4)*  weights = 289415 
1) sns > 0
  5) dta <= 2; criterion = 1, statistic = 1882.439
   6)*  weights = 245457 
  5) dta > 2
   7) dta <= 6; criterion = 1, statistic = 1170.813
     8)*  weights = 328582 
   7) dta > 6

Thanks


回答1:


This function should do the job

 CtreePathFunc <- function(ct, data) {
  # Describe the root-to-leaf path of every terminal node of a party::ctree()
  # fit whose splits are on numeric variables.
  #
  # Args:
  #   ct:   a fitted party BinaryTree (from party::ctree()).
  #   data: the data frame the tree was fitted on. Its rows must line up with
  #         the node weight vectors, because the direction of each split is
  #         inferred from the observations that reach the child node.
  # Returns:
  #   A data.frame with character columns Node (terminal node ID) and Path
  #   ("var <= split, var > split, ..."), one row per terminal node, in order
  #   of first appearance in where(ct). A root-only tree gives Path "".

  terminal_ids <- unique(where(ct))

  describe_path <- function(Node) {
    # Candidate inner nodes: every smaller ID that is not itself terminal.
    # seq_len() keeps this empty for the root (1:0 would be c(1, 0)).
    candidates <- setdiff(seq_len(Node - 1),
                          terminal_ids[terminal_ids < Node])

    # An inner node is an ancestor of Node iff it shares an observation.
    node_weights <- nodes(ct, Node)[[1]]$weights
    path <- Filter(
      function(i) any(node_weights & nodes(ct, i)[[1]]$weights == 1),
      candidates
    )

    conditions <- vapply(seq_along(path), function(k) {
      # Flattened primary split of this ancestor: element 3 is the split
      # point and the last element is the variable name (both as character).
      split_info <- unlist(nodes(ct, path[k])[[1]][[5]])
      var_name <- as.character(split_info[length(split_info)])
      split_point <- split_info[3]

      # The next node on the path is the child this path descends into.
      child <- if (k == length(path)) Node else path[k + 1]
      child_rows <- which(as.logical(nodes(ct, child)[[1]]$weights))
      child_values <- data[[var_name]][child_rows]

      # na.rm = TRUE: a single NA in the child's data would otherwise make
      # all() return NA and the if() below fail.
      goes_left <- all(child_values <= as.numeric(split_point), na.rm = TRUE)
      op <- if (goes_left) "<=" else ">"
      paste(var_name, op, as.character(split_point))
    }, character(1))

    paste(conditions, collapse = ", ")
  }

  data.frame(
    Node = as.character(terminal_ids),
    Path = vapply(terminal_ids, describe_path, character(1)),
    stringsAsFactors = FALSE
  )
}

Testing

library(party)

# Drop the rows whose response (Ozone) is missing.
airq <- airquality[!is.na(airquality$Ozone), ]

ct <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)

Result <- CtreePathFunc(ct, airq)
Result

##   Node                               Path
## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2    3            Temp <= 82, Wind <= 6.9
## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
## 4    9             Temp > 82, Wind > 10.3
## 5    8            Temp > 82, Wind <= 10.3



回答2:


If you use the new, recommended partykit implementation of ctree() rather than the old party package, then you can use the function .list.rules.party(). It is not yet officially exported, but it can be leveraged to extract the desired information.

library("partykit")

# Drop the rows whose response (Ozone) is missing.
airq <- airquality[!is.na(airquality$Ozone), ]
ct <- ctree(Ozone ~ ., data = airq)

# .list.rules.party() is internal to partykit, hence the ':::' access.
partykit:::.list.rules.party(ct)
##                                      3                                      5 
##             "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77" 
##                                      6                                      8 
##  "Temp <= 82 & Wind > 6.9 & Temp > 77"             "Temp > 82 & Wind <= 10.3" 
##                                      9 
##              "Temp > 82 & Wind > 10.3" 



回答3:


Because I needed this function for categorical data, I wrote the following functions, which more or less answer @JoãoDaniel's question (I've only tested them with categorical predictor variables):

# Strip leading and trailing whitespace from every element of `x`.
# Based on http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
trim <- function(x) {
  gsub(pattern = "^\\s+|\\s+$", replacement = "", x = x)
}

# First whitespace-delimited token of each element, e.g. "Temp" from
# "Temp <= { 82 }". Elements containing no whitespace are returned unchanged.
getVariable <- function(x) {
  sub(pattern = "(.*?)[[:space:]].*", replacement = "\\1", x = x)
}

# Second whitespace-delimited token of each element (the comparison symbol),
# e.g. "<=" from "Temp <= { 82 }". Elements with fewer than two whitespace
# characters are returned unchanged.
getSimbolo <- function(x) {
  sub(pattern = "(.*?)[[:space:]](.*?)[[:space:]].*", replacement = "\\2",
      x = x)
}

getReglaFinal <- function(elemento) {
  # Simplify a "; "-separated rule. For every (variable, comparison symbol)
  # pair, keep only its last condition: along a root-to-leaf path, the later
  # (deeper) condition on the same variable and direction is the tighter one.
  #
  # Args:
  #   elemento: a single rule string, e.g.
  #             "Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 }".
  # Returns:
  #   The simplified rule. Conditions are ordered by the first appearance of
  #   their (variable, symbol) pair and joined by "; " -- for the example
  #   above, "Temp <= { 77 }; Wind > { 6.9 }". An empty rule yields "".

  reglas <- trim(unlist(strsplit(as.character(elemento), ";")))
  # Guard: the original loop (1:length(cortes)) failed on an empty rule.
  if (length(reglas) == 0) {
    return("")
  }

  # Cut type = variable name glued to its symbol, e.g. "Temp<=".
  tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))

  # Index of the last condition of each cut type, in first-appearance order.
  ultimos <- vapply(
    unique(tipo_corte),
    function(corte) max(which(tipo_corte == corte)),
    integer(1)
  )
  paste(reglas[ultimos], collapse = "; ")
}

CtreePathFuncAllCat <- function (ct) {
  # Describe every terminal node of a party::ctree() fit as a rule. Split
  # variables may be numeric, ordered factors or unordered factors. The rules
  # are also rendered as an SQL CASE expression.
  #
  # Args:
  #   ct: a fitted party BinaryTree.
  # Returns:
  #   A list with
  #     PreTable: data.frame(Node, Path, Means, Counts, TruePath), rows ordered
  #               by Means. Path is the raw root-to-leaf rule and TruePath is
  #               that rule simplified by getReglaFinal().
  #     Table:    the same rows, with Path replaced by TruePath.
  #     SQL:      one string " CASE  WHEN <rule>  THEN 'Nodo <id>' ...  END ".

  obs_nodes <- where(ct)
  terminal_ids <- unique(obs_nodes)

  # Root-to-leaf rule of one terminal node: "var SB { values }" conditions
  # joined by "; ".
  describe_terminal <- function(Node) {
    # Inner nodes: smaller IDs that are not terminal. seq_len() keeps this
    # empty for a root-only tree (1:0 would be c(1, 0)).
    NonTerminalNodes <- setdiff(seq_len(Node - 1),
                                terminal_ids[terminal_ids < Node])

    # An inner node is an ancestor iff it shares an observation with Node.
    NodeWeights <- nodes(ct, Node)[[1]]$weights
    Path <- Filter(
      function(i) any(NodeWeights & nodes(ct, i)[[1]]$weights == 1),
      NonTerminalNodes
    )

    variablesNombres <- character(0)
    variablesPuntos <- list()
    condiciones <- character(length(Path))

    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]
      goesRight <- nextNodeID == n$right$nodeID

      vec_nombre <- n[[5]]$variableName
      vec_niveles <- attr(n[[5]]$splitpoint, "levels")
      vec_puntos <- as.vector(n[[5]]$splitpoint)

      # The split kind is decided from the levels attribute alone. The old
      # "index == 0" sentinel treated a numeric split at 0 as nominal.
      if (length(vec_niveles) == 0) {
        # Numeric split: describe it by the split value itself.
        SB <- if (goesRight) ">" else "<="
        descripcion <- n[[5]]$splitpoint
      } else if (length(vec_puntos) != length(vec_niveles)) {
        # Ordered factor: splitpoint is the index of the cut level.
        SB <- if (goesRight) ">" else "<="
        flags <- logical(length(vec_niveles))
        flags[vec_puntos] <- TRUE
        vec_puntos <- flags
        descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
      } else {
        # Unordered factor: splitpoint flags one side's levels; take the
        # complement when the path goes right, then intersect with earlier
        # splits on the same variable along this path.
        SB <- "="
        vec_puntos <- if (goesRight) !vec_puntos else as.logical(vec_puntos)
        for (j in which(variablesNombres == vec_nombre)) {
          vec_puntos <- vec_puntos & variablesPuntos[[j]]
        }
        descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
      }

      variablesPuntos[[i]] <- vec_puntos
      variablesNombres[i] <- vec_nombre
      condiciones[i] <- paste(c(vec_nombre, SB, "{", descripcion, "}"),
                              collapse = " ")
    }
    paste(condiciones, collapse = "; ")
  }

  ResulTable <- data.frame(
    Node = as.character(terminal_ids),
    Path = vapply(terminal_ids, describe_terminal, character(1)),
    stringsAsFactors = FALSE
  )

  # Size and prediction of each terminal node, read from the first
  # observation that falls into it. This gives exactly one value per node,
  # in ResulTable's order. unique() on the predictions would merge nodes that
  # share a prediction and misalign the columns.
  first_obs <- match(terminal_ids, obs_nodes)
  Counts <- unname(vapply(weights(ct)[first_obs], sum, numeric(1)))
  # NOTE(review): assumes a univariate response, i.e. drop(Predict(ct)) is a
  # vector with one entry per observation.
  Means <- unname(drop(Predict(ct))[first_obs])

  ResulTable <- data.frame(ResulTable, Means = Means, Counts = Counts)
  ResulTable <- ResulTable[order(ResulTable$Means), ]

  ResulTable$TruePath <- vapply(ResulTable$Path, getReglaFinal, character(1),
                                USE.NAMES = FALSE)

  # Translate "var SB { a, b }" into SQL: ";" -> AND, "{ ... }" -> ('...'),
  # level separators -> ',', and quoted numbers are unquoted.
  # NOTE(review): level labels are not escaped; a label containing a quote
  # would produce invalid SQL.
  ResulTable2 <- ResulTable
  sql_cond <- gsub("\\;", " AND ", ResulTable2$TruePath)
  sql_cond <- gsub("\\{ ", "('", sql_cond)
  sql_cond <- gsub(" \\}", "')", sql_cond)
  sql_cond <- gsub("\\, ", "','", sql_cond)
  sql_cond <- gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", sql_cond)
  ResulTable2$SQL <- paste("WHEN ", sql_cond, " THEN ")
  # Append the node label: "... THEN 'Nodo <id>'".
  ResulTable2$SQL <- paste(ResulTable2$SQL, ResulTable2$Node, sep = "'Nodo ")
  ResulTable2$SQL <- gsub("THEN'", "THEN '",
                          gsub(" '", "'", paste(ResulTable2$SQL, "'")))

  ResultadoFinal <- list()
  ResultadoFinal$PreTable <- ResulTable
  ResultadoFinal$Table <- ResulTable
  ResultadoFinal$Table$Path <- ResultadoFinal$Table$TruePath
  ResultadoFinal$Table$TruePath <- NULL
  ResultadoFinal$SQL <- paste(" CASE ",
                              paste(ResulTable2$SQL, sep = "", collapse = " "),
                              " END ", collapse = "")

  ResultadoFinal
} # CtreePathFuncAllCat

Here is a test:

library(party)

# With ordered factors
TreeModel1 <- ctree(PB ~ ME + SYMPT + HIST + BSE + DECT, data = mammoexp)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2
##$PreTable
##  Node                                                Path    Means Counts
##3    7    DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
##2    6   DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
##1    4  DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
##4    3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333     18
##                                          TruePath
##3   DECT > { Somewhat likely }; SYMPT > { Disagree }
##2  DECT > { Somewhat likely }; SYMPT <= { Disagree }
##1 DECT <= { Somewhat likely }; DECT > { Not likely }
##4                             DECT <= { Not likely }
##
##$Table
##  Node                                               Path    Means Counts
##3    7   DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
##2    6  DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
##1    4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
##4    3                             DECT <= { Not likely } 9.833333     18
##
##$SQL
##[1] " CASE  WHEN  DECT > ('Somewhat likely') AND  SYMPT > ('Disagree')  THEN 'Nodo 7' WHEN  DECT > ('Somewhat likely') AND  SYMPT <= ('Disagree')  THEN 'Nodo 6' WHEN  DECT <= ('Somewhat likely') AND  DECT > ('Not likely')  THEN 'Nodo 4' WHEN  DECT <= ('Not likely')  THEN 'Nodo 3'  END "


# With unordered factors
TreeModel2 <- ctree(count ~ spray, data = InsectSprays)
plot(TreeModel2, type = "simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2
##$PreTable
##Node                                  Path     Means Counts            TruePath
##2    5 spray = { C, D, E }; spray = { C, E }  2.791667     24    spray = { C, E }
##3    4    spray = { C, D, E }; spray = { D }  4.916667     12       spray = { D }
##1    2                   spray = { A, B, F } 15.500000     36 spray = { A, B, F }
##
##$Table
##Node                Path     Means Counts
##2    5    spray = { C, E }  2.791667     24
##3    4       spray = { D }  4.916667     12
##1    2 spray = { A, B, F } 15.500000     36
##
##$SQL
##[1] " CASE  WHEN  spray = ('C','E')  THEN 'Nodo 5' WHEN  spray = ('D')  THEN 'Nodo 4' WHEN  spray = ('A','B','F')  THEN 'Nodo 2'  END "

# With continuous variables: drop the rows whose response (Ozone) is missing.
airq <- airquality[!is.na(airquality$Ozone), ]
TreeModel3 <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
##$PreTable
##  Node                                           Path    Means Counts
##1    5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917     48
##3    6  Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
##4    9                 Temp > { 82 }; Wind > { 10.3 } 48.71429      7
##2    3                Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
##5    8                Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
##                                     TruePath
##1                Temp <= { 77 }; Wind > { 6.9 }
##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
##4                Temp > { 82 }; Wind > { 10.3 }
##2               Temp <= { 82 }; Wind <= { 6.9 }
##5               Temp > { 82 }; Wind <= { 10.3 }
##
##$Table
##  Node                                          Path    Means Counts
##1    5                Temp <= { 77 }; Wind > { 6.9 } 18.47917     48
##3    6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
##4    9                Temp > { 82 }; Wind > { 10.3 } 48.71429      7
##2    3               Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
##5    8               Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
##
##$SQL
##[1] " CASE  WHEN  Temp <= (77) AND  Wind > (6.9)  THEN 'Nodo 5' WHEN  Temp <= (82) AND  Wind > (6.9) AND  Temp > (77)  THEN 'Nodo 6' WHEN  Temp > (82) AND  Wind > (10.3)  THEN 'Nodo 9' WHEN  Temp <= (82) AND  Wind <= (6.9)  THEN 'Nodo 3' WHEN  Temp > (82) AND  Wind <= (10.3)  THEN 'Nodo 8'  END "

Update: the function now supports a mix of categorical and numerical variables!




回答4:


Here is the CtreePathFunc function rewritten in a more Hadley-verse (and, I think, more comprehensible) way. It also handles categorical variables.

library(magrittr)
readSplitter <- function(nodeSplit){
  # Human-readable split point of a party split object: the level label when
  # the split point carries a "levels" attribute (ordered factor), otherwise
  # the numeric split value.
  # NOTE(review): for a nominal split the 0/1 indicator split point would
  # index the levels positionally; in this file only the orderedSplit method
  # calls this helper.
  splitPoint <- nodeSplit$splitpoint
  hasLevels <- "levels" %in% names(attributes(splitPoint))
  if (!hasLevels) {
    return(as.numeric(splitPoint))
  }
  levelLabels <- attr(splitPoint, "levels")
  levelLabels[splitPoint]
}

hasWeigths <- function(ct, path, terminalNode, pathNumber){
  # Row indices of the observations that reach the node following split
  # number `pathNumber` on `path`. That node is the next entry of `path`,
  # or the terminal node itself when this is the last split.
  isLastSplit <- pathNumber == length(path)
  childNode <- if (isLastSplit) terminalNode else path[pathNumber + 1]
  childWeights <- nodes(ct, childNode)[[1]]$weights
  which(as.logical(childWeights))
}

dataFilter <- function(ct, dts, path, terminalNode, pathNumber){
  # Text of the condition for split number `pathNumber` on `path`. It
  # dispatches buildDataFilter() on the primary split (node element 5) of
  # that inner node, using the rows that reach the next node.
  childRows <- hasWeigths(ct, path, terminalNode, pathNumber)
  primarySplit <- nodes(ct, path[pathNumber])[[1]][[5]]
  buildDataFilter(primarySplit, dts, childRows)
}

# S3 generic turning a party split object into a condition string. This file
# defines methods for the "nominalSplit" and "orderedSplit" classes.
buildDataFilter <- function(nodeSplit, ...) {
  UseMethod("buildDataFilter")
}

buildDataFilter.nominalSplit <-
  function(nodeSplit, dts, whichWeights){
    # Describe an unordered-factor split by the levels actually observed in
    # the child node, e.g. "spray == {C, E}".
    #
    # Args:
    #   nodeSplit:    primary split of the parent node (class nominalSplit).
    #   dts:          the data the tree was fitted on.
    #   whichWeights: row indices of the observations reaching the child.
    # Returns:
    #   "<variable> == {<level>, ...}", with levels in order of first
    #   appearance among those rows.
    varName <- nodeSplit$variableName
    # [[ ]] extracts a plain vector for data.frames and tibbles alike.
    nodeValues <- dts[[varName]][whichWeights]
    # NA is not a level. Presumably such rows are routed by surrogate
    # splits, so listing "NA" in the rule would be misleading.
    includedLevels <- unique(nodeValues[!is.na(nodeValues)])
    paste(varName, "==",
          paste0("{", paste(includedLevels, collapse = ", "), "}"))
  }

buildDataFilter.orderedSplit <-
  function(nodeSplit, dts, whichWeights){
    # Describe a numeric or ordered-factor split as "<variable> <= <split>"
    # or "<variable> > <split>". The direction is chosen from the
    # observations that reach the child node.
    #
    # Args:
    #   nodeSplit:    primary split of the parent node (class orderedSplit).
    #   dts:          the data the tree was fitted on.
    #   whichWeights: row indices of the observations reaching the child.
    # Returns:
    #   A single condition string.
    varName <- nodeSplit$variableName
    splitter <- readSplitter(nodeSplit)
    nodeValues <- dts[[varName]][whichWeights]
    # na.rm = TRUE: a single NA would otherwise make all() NA, and the rule
    # would read "<variable> NA <split>".
    direction <- if (all(nodeValues <= splitter, na.rm = TRUE)) "<=" else ">"
    paste(varName, direction, splitter)
  }

readTerminalNodePaths <- function (ct, dts) {
  # Describe the root-to-leaf path of every terminal node of a party::ctree()
  # fit, one condition per split, with conditions joined by " & ".
  #
  # Args:
  #   ct:  a fitted party BinaryTree.
  #   dts: the data the tree was fitted on (rows aligned with the node
  #        weights); passed through to dataFilter().
  # Returns:
  #   data.frame(Node, Path) with one row per terminal node, in order of first
  #   appearance in where(ct). A root-only tree gives Path "".

  sgmnts <- unique(where(ct))
  nodeWeights <- function(Node) nodes(ct, Node)[[1]]$weights
  nodesFirstTreeWeightIsOne <- function(node) nodes(ct, node)[[1]][[2]] == 1

  # Inner nodes with a smaller ID than Node. seq_len() keeps this empty for a
  # root-only tree (1:0 would be c(1, 0) and request node 0).
  innerNodes <- function(Node) {
    setdiff(seq_len(Node - 1), sgmnts[sgmnts < Node])
  }

  # Ancestors of a terminal node: the inner nodes sharing an observation
  # with it.
  pathForTerminalNode <- function(terminalNode) {
    terminalWeights <- nodeWeights(terminalNode)
    Filter(
      function(innerNode) {
        any(terminalWeights & nodesFirstTreeWeightIsOne(innerNode))
      },
      innerNodes(terminalNode)
    )
  }

  describeTerminalNode <- function(terminalNode) {
    path <- pathForTerminalNode(terminalNode)
    # seq_along(): an empty path yields no conditions (seq(0) is c(1, 0)).
    filters <- lapply(seq_along(path), function(nodeNumber) {
      dataFilter(ct, dts, path, terminalNode, nodeNumber)
    })
    paste(unlist(filters), collapse = " & ")
  }

  data.frame(
    Node = sgmnts,
    Path = vapply(sgmnts, describeTerminalNode, character(1)),
    stringsAsFactors = FALSE
  )
}

Testing

shiftFirstPart <- function(vctr, divideBy, proportion = .5){
  # Divide the first round(length(vctr) * proportion) elements of `vctr` by
  # `divideBy` and return the result, e.g. to skew simulated data.
  #
  # Args:
  #   vctr:       numeric vector.
  #   divideBy:   divisor applied to the leading elements.
  #   proportion: share of elements (from the start) to rescale.
  # Returns:
  #   `vctr` with its leading elements divided by `divideBy`.
  nHead <- round(length(vctr) * proportion)
  # Clamp the count to [0, length(vctr)] and use seq_len(). seq(0) is c(1, 0),
  # which wrongly rescaled the first element (and corrupted empty input).
  nHead <- min(max(nHead, 0), length(vctr))
  headIdx <- seq_len(nHead)
  vctr[headIdx] <- vctr[headIdx] / divideBy
  vctr
}
# Simulated example: each column's first half is skewed downwards by
# shiftFirstPart(), so is_buyer is associated with the other columns.
set.seed(11)
n <- 13000
gdt <- data.frame(
  is_buyer = runif(n) %>% shiftFirstPart(1.5) %>% round %>%
    factor(labels = c("no", "yes")),
  # `include.lowest` is the real cut() argument; the former `include_lowest`
  # was silently swallowed by `...`.
  age = runif(n) %>% shiftFirstPart(1.5) %>%
    cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE,
        ordered_result = TRUE, labels = c("low", "mid", "high")),
  city = runif(n) %>% shiftFirstPart(1.5) %>%
    cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE,
        labels = c("Chicago", "Boston", "Memphis")),
  point = runif(n) %>% shiftFirstPart(1.2)
)

gct <- ctree(is_buyer ~ ., data = gdt)
readTerminalNodePaths(gct, gdt)


来源:https://stackoverflow.com/questions/21443203/ctree-how-to-get-the-list-of-splitting-conditions-for-each-terminal-node

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!