Testing rules generated by Rpart package

Submitted by  ̄綄美尐妖づ on 2019-12-03 07:24:31

I managed to solve this in the following way.

DISCLAIMER: There are surely better ways of solving this, but this hack works and does what I want. (I am not very proud of it. It is hackish, but it works.)

OK, let's start. The basic idea is to use the sqldf package.

The last piece of code in the question stores every step of the path through the tree in a list, so I will start from there.

        library(sqldf)
        library(stringr)

        # Flatten the list to a character vector and drop the first element ("root")
        rule.v <- unlist(rule, use.names=FALSE)[-1]
        # Replace the dots in column names with underscores; sqldf doesn't handle dots in names
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
        # Turn each "name=" into "name in ('" to open a SQL IN list (">=" is not affected)
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
        # Wrap the values of each list in single quotes by turning "," into "','"
        # The last value has no trailing comma, so it is closed separately below (any better ideas?)
        rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")

        # Close the last value of each IN list with an apostrophe and a ")"
        # Match " in ('" literally so that names that merely contain "in" are left alone
        is.in <- str_detect(rule.v, fixed(" in ('"))
        rule.v[is.in] <- paste0(rule.v[is.in], "')")

        # Collapse the vector into one string joined by " AND "
        rule.v <- paste(rule.v, collapse = " AND ")

        # Generate the query
        # Use any metric that you can get from the data frame
        query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")

        # For debug only...
        print(query)

        # Execute and print the results
        print(sqldf(query))

And that's all!

I warned you, it was hackish...

Hopefully this helps someone else ...

Thanks for all the help and suggestions!
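
To make the string rewriting in the sqldf approach above easier to follow, here is a minimal sketch of the same four steps applied to a made-up rule vector. It uses base R's gsub instead of stringr so that it runs without extra packages. The variable Soil.Type and its levels are hypothetical and are only there to show how a categorical split ends up as a SQL IN list.

```r
# Hypothetical path: one numeric split and one categorical split
rule.v <- c("Petal.Length>=2.45", "Soil.Type=clay,loam")

# 1. Dots between letters become underscores (decimal points are left alone)
rule.v <- gsub("([a-zA-Z])\\.([a-zA-Z])", "\\1_\\2", rule.v)
# 2. "name=" becomes "name in ('"; ">=" is untouched because ">" is not alphanumeric
rule.v <- gsub("([a-zA-Z0-9])=", "\\1 in ('", rule.v)
# 3. Commas between values become "','"
rule.v <- gsub(",", "','", rule.v, fixed = TRUE)
# 4. Close every IN list with "')"
is.in <- grepl(" in ('", rule.v, fixed = TRUE)
rule.v[is.in] <- paste0(rule.v[is.in], "')")

# Join the conditions into a WHERE clause
where <- paste(rule.v, collapse = " AND ")
print(where)
```

This should print "Petal_Length>=2.45 AND Soil_Type in ('clay','loam')", which is the text that gets pasted after WHERE in the query.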

In general I don't recommend using eval(parse(...)), but in this case it seems to work:

Extract the rule:

        rule <- unname(unlist(path.rpart(model, nodes=7)))[-1]

         node number: 7 
           root
           Petal.Length>=2.45
           Petal.Width>=1.75
        rule
        [1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

Extract the data using the rule:

        node_data <- with(iris, iris[eval(parse(text=paste(rule, collapse=" & "))), ])
        head(node_data)

            Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
        71           5.9         3.2          4.8         1.8 versicolor
        101          6.3         3.3          6.0         2.5  virginica
        102          5.8         2.7          5.1         1.9  virginica
        103          7.1         3.0          5.9         2.1  virginica
        104          6.3         2.9          5.6         1.8  virginica
        105          6.5         3.0          5.8         2.2  virginica
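
The two steps above can be wrapped in a small helper. This is only a sketch: the question's model is not shown here, so it assumes a default rpart(Species ~ ., data = iris) fit, and node_rows is a made-up name. It only works for paths made of numeric splits, because a categorical split is printed as something like "x=a,b", which is not valid R.

```r
library(rpart)

# Assumed stand-in for the model from the question
model <- rpart(Species ~ ., data = iris)

# Hypothetical helper: return the rows of `data` that reach `node`
node_rows <- function(fit, node, data) {
  # print.it = FALSE stops path.rpart from echoing the path to the console
  rule <- unname(unlist(path.rpart(fit, nodes = node, print.it = FALSE)))[-1]
  if (length(rule) == 0) return(data)  # the root node has no conditions
  # Evaluate the combined condition with the data frame's columns in scope
  keep <- eval(parse(text = paste(rule, collapse = " & ")), envir = data)
  data[keep, , drop = FALSE]
}

node7 <- node_rows(model, 7, iris)
print(table(node7$Species))
```

Tabulating the response for the extracted rows gives the same kind of per-node summary as the SQL query in the first answer.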

Starting with this rule:

        Rule number: 16 [yval=bad cover=220 N=121 Y=99 (37%) prob=0.04]
        checking< 2.5
        afford< 54
        history< 3.5
        coapp< 2.5

You would have a 'prob' vector, initialised to all zeros, that you could update with rule 16:

        # Set prob to 0.04 for the rows that satisfy all four conditions of rule 16
        prob <- ifelse( dat[['checking']] < 2.5 &
                        dat[['afford']]   < 54  &
                        dat[['history']]  < 3.5 &
                        dat[['coapp']]    < 2.5, 0.04, prob )

You would then need to run through all the other rules (which should not change any probabilities for these cases, since the rules of a tree are mutually exclusive). There are likely to be much more efficient methods than this for constructing predictions, for instance the predict.rpart function.
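
As a rough sketch of that more direct route: a fitted rpart object already records which leaf each training row landed in, and predict returns the class probabilities without any rule parsing. The credit data (dat) from the question is not available here, so this assumes a default tree on iris as a stand-in.

```r
library(rpart)

# Assumed stand-in model; the question's own data is not reproduced here
model <- rpart(Species ~ ., data = iris)

# model$where holds, for every training row, the row of model$frame for its leaf;
# the row names of model$frame are the node numbers
leaf <- as.integer(rownames(model$frame)[model$where])
print(table(leaf))

# Class probabilities for each row, taken from the leaf it falls into
prob <- predict(model, newdata = iris, type = "prob")
print(head(cbind(leaf, round(prob, 2))))
```

Splitting or aggregating the data by leaf then gives per-rule summaries for all terminal nodes in one pass.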
