R: Obtaining Rules from a Function

Submitted by 蓝咒 on 2021-01-30 09:09:58

Question


I am using the R programming language. I used the "rpart" library and fit a decision tree using some data:

    # From a previous question: https://stackoverflow.com/questions/65678552/r-changing-plot-sizes

    library(rpart)

    car.test.frame$Reliability = as.factor(car.test.frame$Reliability)

    z.auto <- rpart(Reliability ~ ., car.test.frame)
    plot(z.auto)
    text(z.auto, use.n=TRUE, xpd=TRUE, cex=.8)

This works, but I am looking for an easier way to summarize the results of this tree in case it becomes too big and cluttered to visualize. Another Stack Overflow post, "Extracting Information from the Decision Rules in rpart package", shows how to obtain a listing of rules:

    library(party)
    library(partykit)

    party_obj <- as.party.rpart(z.auto, data = TRUE)
    decisions <- partykit:::.list.rules.party(party_obj)
    cat(paste(decisions, collapse = "\n"))

This returns the following list of rules (each line is a rule corresponding to the plot of "z.auto"):

    Country %in% c("NA", "Germany", "Korea", "Mexico", "Sweden", "USA") & Weight >= 3167.5
    Country %in% c("NA", "Germany", "Korea", "Mexico", "Sweden", "USA") & Weight < 3167.5
    Country %in% c("NA", "Japan", "Japan/USA")

However, from this list, it is not possible to know which rule results in which value of "Reliability". For the time being, I am manually interpreting the tree and manually tracing each rule to the result, but is there a way to add to each line "the corresponding value of reliability"?

e.g. Is it possible to produce something like this?

    Country %in% c("NA", "Germany", "Korea", "Mexico", "Sweden", "USA") & Weight >= 3167.5 then reliability = 3,7,4,0
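One way to get the outcome next to each leaf using only rpart is to cross-tabulate the leaf each training row falls into against its observed class. This is a minimal sketch, assuming the fit kept rpart's default `y = TRUE` so that `z.auto$y` holds the class codes; `leaf_id`, `observed` and `leaf_table` are illustrative names, not part of any package.

```r
library(rpart)

# Same fit as in the question; car.test.frame ships with rpart.
car.test.frame$Reliability <- as.factor(car.test.frame$Reliability)
z.auto <- rpart(Reliability ~ ., car.test.frame)

# z.auto$where gives, for each training row used in the fit,
# the row of z.auto$frame that holds its leaf.
leaf_id <- row.names(z.auto$frame)[z.auto$where]

# z.auto$y holds integer class codes; map them back to level names.
ylevels  <- attr(z.auto, "ylevels")
observed <- factor(ylevels[z.auto$y], levels = ylevels)

# One row per leaf node, one column per Reliability level.
leaf_table <- table(node = leaf_id, Reliability = observed)
print(leaf_table)
```

The row names of the table are rpart node numbers, so they can be matched to the rules by node.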

(Note 1: I am also not sure why the countries are appearing as "befgh" instead of their actual names.

Note 2: I am aware that there is a library "rpart.plot" that has a simpler way of obtaining these rules. However, I am using a computer that does not have internet access or a USB port, so I cannot download the rpart.plot library. I have R with a few preloaded packages. I am trying to obtain the decision rules using libraries such as rpart, dplyr, purrr, party, partykit, and functions from base R.)
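On the letter codes in note 1: a likely explanation is that rpart's default split labels abbreviate factor levels to single letters (a for the first level, b for the second, and so on). The sketch below compares the default labels with `minlength = 0`, which asks `labels()` for the full level names; `short_labels` and `full_labels` are illustrative names.

```r
library(rpart)

# Same fit as in the question; car.test.frame ships with rpart.
car.test.frame$Reliability <- as.factor(car.test.frame$Reliability)
z.auto <- rpart(Reliability ~ ., car.test.frame)

# Default: factor levels in split labels are abbreviated to letters.
short_labels <- labels(z.auto)

# minlength = 0: keep the full factor level names.
full_labels <- labels(z.auto, minlength = 0)

print(short_labels)
print(full_labels)
```

For the plot itself, `text(z.auto, pretty = 0)` should have the same effect on the labels drawn on the tree.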

Thanks


Answer 1:


This isn't my area of expertise, but perhaps this function (from https://www.togaware.com/datamining/survivor/Convert_Tree.html) will do what you want to do:

    library(rpart)
    car.test.frame$Reliability = as.factor(car.test.frame$Reliability)
    z.auto <- rpart(Reliability ~ ., car.test.frame)
    plot(z.auto, margin = 0.25)
    text(z.auto, pretty = TRUE, cex = 0.8,
         splits = TRUE, use.n = TRUE, all = FALSE)

    list.rules.rpart <- function(model)
    {
      if (!inherits(model, "rpart")) stop("Not a legitimate rpart tree")
      #
      # Get some information.
      #
      frm     <- model$frame
      names   <- row.names(frm)
      ylevels <- attr(model, "ylevels")
      ds.size <- model$frame[1,]$n
      #
      # Print each leaf node as a rule.
      #
      for (i in 1:nrow(frm))
      {
        if (frm[i,1] == "<leaf>")
        {
          # The following [,5] is hardwired - needs work!
          cat("\n")
          cat(sprintf(" Rule number: %s ", names[i]))
          cat(sprintf("[yval=%s cover=%d (%.0f%%) prob=%0.2f]\n",
                      ylevels[frm[i,]$yval], frm[i,]$n,
                      round(100*frm[i,]$n/ds.size), frm[i,]$yval2[,5]))
          pth <- path.rpart(model, nodes=as.numeric(names[i]), print.it=FALSE)
          cat(sprintf("   %s\n", unlist(pth)[-1]), sep="")
        }
      }
    }

    list.rules.rpart(z.auto)
     Rule number: 4 [yval=3 cover=10 (20%) prob=0.00]
       Country=Germany,Korea,Mexico,Sweden,USA
       Weight>=3168

     Rule number: 5 [yval=2 cover=18 (37%) prob=4.00]
       Country=Germany,Korea,Mexico,Sweden,USA
       Weight< 3168

     Rule number: 3 [yval=5 cover=21 (43%) prob=2.00]
       Country=Japan,Japan/USA
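The `prob` field in this output comes from the hardwired column `yval2[,5]`, which for this five-level response appears to hold a class count rather than a probability (hence values such as 4.00). Below is a minimal sketch of a variant that returns a data frame with the predicted class, its probability, and the per-class counts for each leaf. It assumes rpart's usual `yval2` layout for classification trees (predicted class first, then the class counts, then the class probabilities); `leaf_rules` is a hypothetical helper name, not part of rpart.

```r
library(rpart)

# Same fit as above; car.test.frame ships with rpart.
car.test.frame$Reliability <- as.factor(car.test.frame$Reliability)
z.auto <- rpart(Reliability ~ ., car.test.frame)

# Hypothetical helper (not part of rpart): one row per leaf node.
leaf_rules <- function(model) {
  stopifnot(inherits(model, "rpart"), model$method == "class")
  frm     <- model$frame
  ylevels <- attr(model, "ylevels")
  nclass  <- length(ylevels)
  leaves  <- which(frm$var == "<leaf>")
  node_id <- as.numeric(row.names(frm)[leaves])

  # path.rpart returns one character vector per node; drop the "root" entry.
  paths <- path.rpart(model, nodes = node_id, print.it = FALSE)
  rule  <- vapply(paths, function(p) paste(p[-1], collapse = " & "),
                  character(1))

  # Assumed yval2 layout: column 1 = predicted class, then nclass
  # columns of class counts, then nclass columns of class probabilities.
  yval2  <- frm$yval2[leaves, , drop = FALSE]
  counts <- yval2[, 1 + seq_len(nclass), drop = FALSE]
  probs  <- yval2[, 1 + nclass + seq_len(nclass), drop = FALSE]
  pred   <- frm$yval[leaves]

  data.frame(
    node   = node_id,
    rule   = unname(rule),
    yval   = ylevels[pred],
    cover  = frm$n[leaves],
    # Probability of the predicted class in each leaf.
    prob   = probs[cbind(seq_along(pred), pred)],
    # Per-class counts, in the order of the response levels.
    counts = unname(apply(counts, 1, paste, collapse = ",")),
    stringsAsFactors = FALSE
  )
}

rules <- leaf_rules(z.auto)
print(rules)
```

Returning a data frame rather than printing makes it straightforward to filter or sort the rules afterwards, for example with dplyr.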



Source: https://stackoverflow.com/questions/65679523/r-obtaining-rules-from-a-function
