Testing rules generated by the rpart package

不知归路 2021-02-09 08:42

I want to programmatically test a rule generated from a tree. In a tree, the path between the root and a leaf (terminal node) can be interpreted as a rule.
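
For concreteness, here is a minimal sketch of the kind of setup the answers appear to assume. The model definition is an assumption (a default `rpart` classification tree on the built-in `iris` data), chosen because it reproduces the node 7 path shown in the first answer. `leaf_nodes` and `paths` are illustrative names only.

```r
library(rpart)

# Assumption: a default classification tree on the built-in iris data
model <- rpart(Species ~ ., data = iris)

# Node numbers are the row names of model$frame; leaves have var == "<leaf>"
leaf_nodes <- as.numeric(rownames(model$frame)[model$frame$var == "<leaf>"])

# path.rpart() returns a list with one character vector per requested node,
# named by node number; print.it = FALSE suppresses the console output
paths <- path.rpart(model, nodes = leaf_nodes, print.it = FALSE)
print(paths)
```

Each element of `paths` starts with `"root"` followed by one condition per split, so each leaf corresponds to one rule.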

2 Answers
  • 2021-02-09 09:22

    In general, I don't recommend using eval(parse(...)), but in this case it seems to work:

    Extract the rule:

    rule <- unname(unlist(path.rpart(model, nodes=7)))[-1]
    
     node number: 7 
       root
       Petal.Length>=2.45
       Petal.Width>=1.75
    rule
    [1] "Petal.Length>=2.45" "Petal.Width>=1.75" 
    

    Extract the data using the rule:

    node_data <- with(iris, iris[eval(parse(text=paste(rule, collapse=" & "))), ])
    head(node_data)
    
        Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
    71           5.9         3.2          4.8         1.8 versicolor
    101          6.3         3.3          6.0         2.5  virginica
    102          5.8         2.7          5.1         1.9  virginica
    103          7.1         3.0          5.9         2.1  virginica
    104          6.3         2.9          5.6         1.8  virginica
    105          6.5         3.0          5.8         2.2  virginica
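
    To test a rule programmatically rather than by eye, one option is to wrap the two steps above in a function and compare its result with what the tree itself recorded. This is only a sketch: `rows_for_node` is a hypothetical helper, the model is assumed to be a default `rpart` fit on `iris`, and it only handles numeric splits on data without missing values. Because `path.rpart` prints split points with limited digits, the comparison may fail on data where rounding matters.

    ```r
    library(rpart)

    # Assumed model, consistent with the node 7 path shown above
    model <- rpart(Species ~ ., data = iris)

    # Hypothetical helper: rows of `data` satisfying every condition on the path to `node`
    rows_for_node <- function(model, data, node) {
      rule <- unname(unlist(path.rpart(model, nodes = node, print.it = FALSE)))[-1]
      if (length(rule) == 0) return(data)  # the root node has no conditions
      keep <- eval(parse(text = paste(rule, collapse = " & ")), envir = data)
      data[keep, ]
    }

    node_data <- rows_for_node(model, iris, 7)

    # Cross-check against the tree: model$where gives, for each training row,
    # the row index in model$frame of the leaf it ended up in
    in_leaf_7 <- rownames(model$frame)[model$where] == "7"
    stopifnot(nrow(node_data) == sum(in_leaf_7))
    ```

    The `model$where` comparison only makes sense for leaf nodes, since internal nodes never appear there.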
    
  • 2021-02-09 09:26

    I managed to solve this in the following way.

    DISCLAIMER: There are obviously better ways of solving this, but this hack works and does what I want. (I am not very proud of it; it is hackish, but it works.)

    OK, let's start. The basic idea is to use the sqldf package.

    If you check the question, the last piece of code puts every piece of the tree path into a list. So I will start from there:

            library(sqldf)
            library(stringr)
    
            # `rule` is the list returned by path.rpart() in the question.
            # Transform it to a character vector and drop the leading "root" entry
            rule.v <- unlist(rule, use.names=FALSE)[-1]
            # Replace the dots in column names with underscores, because a dot is an SQL operator.
            # This relies on sqldf exposing Petal.Length as Petal_Length, which may depend on
            # the sqldf version; if the dots are kept, the names need to be quoted instead
            rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
            # Turn categorical splits "var=a,b" into the start of an SQL list: "var in ('a,b"
            # (">=" is not affected, because the character before "=" must be alphanumeric)
            rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
            # Quote every element in the lists of values: "a,b" becomes "a','b"
            rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")
    
            # Close the last element of each list with an apostrophe and a ")".
            # Match the literal " in ('" rather than "in", which would also hit names such as "income"
            is.list.rule <- str_detect(rule.v, fixed(" in ('"))
            rule.v[is.list.rule] <- paste0(rule.v[is.list.rule], "')")
    
            # Collapse all the conditions into one string joined by " AND "
            rule.v <- paste(rule.v, collapse = " AND ")
    
            # Generate the query
            # Use any metric that you can get from the data frame
            query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")
    
            # For debugging only...
            print(query)
    
            # Execute and print the results
            print(sqldf(query))
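
    The same rewriting idea can also stay in plain R, which avoids the dot and quoting issues of the SQL route. The sketch below is only an illustration: `rule_to_r` is a hypothetical helper, and it assumes categorical splits are printed as `var=a,b`, which is the format the regular expressions above expect. Level names containing commas or double quotes are not handled.

    ```r
    # Hypothetical helper: turn one path.rpart() condition into R syntax.
    # Numeric conditions ("x>=1.5", "x< 1.5") are already valid R and pass through.
    rule_to_r <- function(rule) {
      # A categorical split looks like "var=a,b": the text before the first "="
      # contains no "<" or ">"
      if (!grepl("^[^<>=]+=", rule)) return(rule)
      var <- sub("=.*$", "", rule)
      levels <- strsplit(sub("^[^=]*=", "", rule), ",", fixed = TRUE)[[1]]
      sprintf("%s %%in%% c(%s)", var, paste0('"', levels, '"', collapse = ", "))
    }

    # Hand-written rules in the assumed format, not the output of a fitted tree
    rules <- c("Petal.Length>=2.45", "Species=versicolor,virginica")
    expr <- paste(vapply(rules, rule_to_r, character(1)), collapse = " & ")
    print(expr)

    # Evaluate the combined condition inside the data frame
    selected <- iris[eval(parse(text = expr), envir = iris), ]
    print(table(selected$Species))
    ```

    The result is an ordinary data frame, so any summary can be computed on it without going through a query string.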
    

    And that's all!

    I warned you, it was hackish...

    Hopefully this helps someone else ...

    Thanks for all the help and suggestions!
