I want to test programmatically a rule generated from a tree. In a tree, the path between the root and a leaf (terminal node) can be interpreted as a rule.
In R, we can use the rpart package and do the following (in this post I use the iris data set, for example purposes only):
library(rpart)
model <- rpart(Species ~ ., data=iris)
With these two lines I get a tree named model, an object of class rpart (described under rpart.object in the rpart documentation, page 21). This object holds a lot of information and supports a variety of methods. In particular, it has a frame component (accessible in the standard way: model$frame) (idem), and there is the function path.rpart (rpart documentation, page 7), which gives the path from the root node to the node of interest (the nodes argument of the function).
The row.names of the frame component contain the node numbers of the tree. The var column gives the split variable at each node, yval the fitted value, and yval2 the class probabilities and other information.
> model$frame
           var   n  wt dev yval complexity ncompete nsurrogate     yval2.1     yval2.2     yval2.3     yval2.4     yval2.5     yval2.6     yval2.7
1 Petal.Length 150 150 100    1       0.50        3          3  1.00000000 50.00000000 50.00000000 50.00000000  0.33333333  0.33333333  0.33333333
2       <leaf>  50  50   0    1       0.01        0          0  1.00000000 50.00000000  0.00000000  0.00000000  1.00000000  0.00000000  0.00000000
3  Petal.Width 100 100  50    2       0.44        3          3  2.00000000  0.00000000 50.00000000 50.00000000  0.00000000  0.50000000  0.50000000
6       <leaf>  54  54   5    2       0.00        0          0  2.00000000  0.00000000 49.00000000  5.00000000  0.00000000  0.90740741  0.09259259
7       <leaf>  46  46   1    3       0.01        0          0  3.00000000  0.00000000  1.00000000 45.00000000  0.00000000  0.02173913  0.97826087
Only the rows marked as <leaf> in the var column are terminal nodes (leaves). In this case those are nodes 2, 6 and 7.
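The leaf node numbers can also be read off the frame in code rather than by eye. The following is only a minimal sketch; the names is_leaf and leaf_nodes are illustrative and not part of rpart.

```r
library(rpart)

model <- rpart(Species ~ ., data = iris)

# Rows of model$frame whose var column is "<leaf>" are terminal nodes
is_leaf <- model$frame$var == "<leaf>"

# The node numbers are stored as the row names of the frame
leaf_nodes <- as.integer(row.names(model$frame)[is_leaf])
print(leaf_nodes)
```

The resulting vector can then be used as the nodes argument of path.rpart to extract one rule per leaf.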
As mentioned above, you can use the path.rpart function to extract a rule (this technique is used in the rattle package and in the article Sharma Credit Score), as shown further below.
Additionally, the model keeps the levels of the predicted value in an attribute:
predicted.levels <- attr(model, "ylevels")
These levels correspond to the yval column of the model$frame data frame: yval is an index into them.
For the leaf with node number 7 (row number 5), the predicted value is
> predicted.levels[model$frame[5, ]$yval]
[1] "virginica"
and the rule is
> rule <- path.rpart(model, nodes = 7)
node number: 7
root
Petal.Length>=2.45
Petal.Width>=1.75
So, the rule could be read as
If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = virginica
I know that I can test how many true positives this rule has on a testing data set (here I use the iris data set again) by subsetting the data as follows
> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)
and then calculating the confusion matrix
> table(hits$Species, hits$Species == "virginica")
             FALSE TRUE
  setosa         0    0
  versicolor     1    0
  virginica      0   45
(Note: I used the same iris data set as testing)
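The counts in that table can be condensed into two summary numbers for the rule. This is just a sketch built on the same subset; the names precision and recall are my own choice, not something rpart provides.

```r
# The subset selected by the rule, as in the question
hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)

# Precision: share of the rows selected by the rule that really are virginica
precision <- mean(hits$Species == "virginica")

# Recall: share of all virginica rows that the rule selects
recall <- sum(hits$Species == "virginica") / sum(iris$Species == "virginica")

print(c(precision = precision, recall = recall))
```

With the counts shown above (45 true positives, 1 false positive, 50 virginica rows in total), this works out to 45/46 and 45/50.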
How could I evaluate the rule programmatically? I can extract the conditions from the rule as follows
> unlist(rule, use.names = FALSE)[-1]
[1] "Petal.Length>=2.45" "Petal.Width>=1.75"
But how can I continue from here? I cannot pass these strings directly to the subset function.
Thanks in advance
NOTE: This question has been heavily edited for better clarity
I was able to solve this in the following way.
DISCLAIMER: There must obviously be better ways of solving this, but this hack works and does what I want (I am not very proud of it; it is hackish, but it works).
OK, let's start. The basic idea is to use the sqldf package.
If you check the last piece of code in the question, it puts every piece of the tree path into a list. So I will start from there:
library(sqldf)
library(stringr)
# Transform to a character vector
rule.v <- unlist(rule, use.names=FALSE)[-1]
# Replace the dots in variable names with underscores; sqldf doesn't handle dots in names
rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
# Turn each bare equal sign (categorical splits) into "in ('"
rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
# Wrap the elements in the lists of values in single quotes
# The last element can't be handled this way (any ideas?)
rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")
# Close the last element with a single quote and a ")"
for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {
rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")
}
# Collapse all the list in one string joined by " AND "
rule.v <- paste(rule.v, collapse = " AND ")
# Generate the query
# Use any metric that you can get from the data frame
query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")
# For debug only...
print(query)
# Execute and print the results
print(sqldf(query))
And that's all!
I warned you, it was hackish...
Hopefully this helps someone else ...
Thanks for all the help and suggestions!
In general I don't recommend using eval(parse(...)), but in this case it seems to work:
Extract the rule:
rule <- unname(unlist(path.rpart(model, nodes=7)))[-1]
node number: 7
root
Petal.Length>=2.45
Petal.Width>=1.75
rule
[1] "Petal.Length>=2.45" "Petal.Width>=1.75"
Extract the data using the rule:
node_data <- with(iris, iris[eval(parse(text=paste(rule, collapse=" & "))), ])
head(node_data)
    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
71           5.9         3.2          4.8         1.8 versicolor
101          6.3         3.3          6.0         2.5  virginica
102          5.8         2.7          5.1         1.9  virginica
103          7.1         3.0          5.9         2.1  virginica
104          6.3         2.9          5.6         1.8  virginica
105          6.5         3.0          5.8         2.2  virginica
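The same idea can be wrapped in a small function so that any node can be checked. This is only a sketch: rows_for_node is a hypothetical helper name, not part of rpart, and it assumes all splits are numeric (for factor splits path.rpart prints conditions such as var=a,b, which are not valid R expressions). It uses the print.it argument of path.rpart to suppress the printed path.

```r
library(rpart)

model <- rpart(Species ~ ., data = iris)

# Hypothetical helper: returns the rows of `data` that satisfy the path to `node`.
# Assumes numeric splits only, so every condition parses as an R expression.
rows_for_node <- function(model, node, data) {
  path <- path.rpart(model, nodes = node, print.it = FALSE)
  conditions <- unname(unlist(path))[-1]       # drop the leading "root"
  if (length(conditions) == 0) return(data)    # the root node matches everything
  keep <- eval(parse(text = paste(conditions, collapse = " & ")), envir = data)
  data[keep, , drop = FALSE]
}

# Number of rows reaching each leaf, and the species found in node 7
print(sapply(c(2, 6, 7), function(n) nrow(rows_for_node(model, n, iris))))
print(table(rows_for_node(model, 7, iris)$Species))
```

Evaluating the conditions with the data frame as the environment avoids the with() wrapper; the usual caveats about eval(parse(...)) on untrusted strings still apply.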
Starting with
Rule number: 16 [yval=bad cover=220 N=121 Y=99 (37%) prob=0.04]
checking< 2.5
afford< 54
history< 3.5
coapp< 2.5
You would have a 'prob' vector that started out as all zeros, that you could update with rule16:
prob <- ifelse( dat[['checking']] < 2.5 &
                dat[['afford']] < 54 &
                dat[['history']] < 3.5 &
                dat[['coapp']] < 2.5, 0.04, prob )
You would then need to run through all the other rules (which should not change any probabilities for this case, since the rules of a tree should be mutually exclusive). There are likely to be much more efficient methods than this for constructing predictions, for instance the predict.rpart function.
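As a sketch of that route, using the iris tree from the question rather than the credit data above: predict applies every rule at once, and for the training data the where component of the fitted object records which leaf each case fell into. The variable names pred and leaf_of_case are illustrative.

```r
library(rpart)

model <- rpart(Species ~ ., data = iris)

# Predicted class for each row (iris is reused as the "testing" data, as in the question)
pred <- predict(model, newdata = iris, type = "class")

# Confusion matrix: actual species in rows, predicted species in columns
print(table(actual = iris$Species, predicted = pred))

# For the training data, model$where holds the row of model$frame (the leaf) each case fell into
leaf_of_case <- row.names(model$frame)[model$where]
print(table(leaf_of_case))
```

Subsetting the data by leaf_of_case == "7" gives the cases covered by the rule for node 7 without parsing any rule text, although this shortcut only covers the data the tree was fitted on.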
Source: https://stackoverflow.com/questions/11831794/testing-rules-generated-by-rpart-package