Question
I'm interested in writing a recursive binary tree algorithm. Consider the following data, in which I've already sorted the covariate x:
# Example data set, already sorted by the covariate x
mydata <- data.frame(
  x = c(10, 20, 25, 35),
  y = c(-10.5, 6.5, 7.5, -7.5)
)
> mydata
x y
1 10 -10.5
2 20 6.5
3 25 7.5
4 35 -7.5
Suppose my final tree looks something like this:
           [-10.5, 6.5, 7.5, -7.5]
              /               \
         [-10.5]        [6.5, 7.5, -7.5]
                           /         \
                    [6.5, 7.5]     [-7.5]
I want the final output of my function to return a list that contains all of the nodes:
> final_tree
[[1]]
[[1]][[1]]
x y
1 10 -10.5
2 20 6.5
3 25 7.5
4 35 -7.5
[[2]]
[[2]][[1]]
x y
1 10 -10.5
[[2]][[2]]
x y
1 20 6.5
2 25 7.5
3 35 -7.5
[[3]]
[[3]][[1]]
NULL
[[3]][[2]]
NULL
[[3]][[3]]
x y
1 20 6.5
2 25 7.5
[[3]][[4]]
x y
1 35 -7.5
I am splitting my tree at every node with a random split by using best_split_ind. If best_split_ind = 1, then the 1st instance in node_parent will end up in node_left, and the rest end up in node_right. If best_split_ind = 3, then the first three instances in node_parent will end up in node_left, and the rest end up in node_right.
Here's what I have so far:
# Initialize an empty tree: a list with one element per level, where level k
# is a list of 2^(k - 1) NULL slots (one slot per possible node at that level).
create_empty_tree <- function(max_height) {
  # lapply() and vector("list", n) never simplify their result, so this is a
  # list of lists for every max_height (sapply() collapsed the outer level
  # when max_height was 1), and seq_len() yields an empty tree for 0.
  lapply(seq_len(max_height), function(k) vector("list", 2^(k - 1)))
}
# Create empty tree with max_height = 3
tree_struc <- create_empty_tree(max_height = 3)
# Recursively grow a random binary tree and store every node in tree_struc.
#
# tree_struc is a list of levels as built by create_empty_tree(): level k is a
# list of 2^(k - 1) slots, and the children of the node in slot p of level k
# go in slots 2p - 1 (left) and 2p (right) of level k + 1. Children of a node
# that is not split stay NULL.
#
# node_parent: data frame for the current node; rows are assumed to be sorted
#              by x already, since the split is purely by row position.
# max_height:  number of levels to grow (the root is level 1); tree_struc
#              must have at least this many levels.
# tree_struc:  the tree built so far.
# height:      level of node_parent (1 for the root).
# position:    slot of node_parent within its level (1 for the root).
# Returns the updated tree_struc.
grow_tree <- function(node_parent, max_height, tree_struc, height,
                      position = 1) {
  # Store this node in its own slot rather than overwriting the whole level.
  tree_struc[[height]][[position]] <- node_parent

  # Stop at the last level, and never split a single row (there is no valid
  # split index, and sample.int(0, 1) would fail).
  if (height < max_height && nrow(node_parent) > 1) {
    # The first best_split_ind rows go left, the rest go right.
    best_split_ind <- sample.int(nrow(node_parent) - 1L, 1L)
    left_rows <- seq_len(best_split_ind)
    # drop = FALSE keeps single-column data frames from collapsing to vectors.
    node_left <- node_parent[left_rows, , drop = FALSE]
    node_right <- node_parent[-left_rows, , drop = FALSE]

    # Each recursive call returns the whole updated tree, so it replaces
    # tree_struc instead of being nested inside one of its slots.
    tree_struc <- grow_tree(node_left, max_height, tree_struc, height + 1,
                            2 * position - 1)
    tree_struc <- grow_tree(node_right, max_height, tree_struc, height + 1,
                            2 * position)
  }
  tree_struc
}
# Grow the tree from the full data set; the root lives at height 1.
grow_tree(
  node_parent = mydata,
  height = 1,
  max_height = 3,
  tree_struc = tree_struc
)
The resulting tree is not correct. I think it has to do with how I recursively called the function on the left and right child nodes. Can anyone point me in the right direction?
Answer 1:
I may have misunderstood you, but you can simplify quite a bit here by using two functions that call each other recursively. There's no need to set up an initial container.
The first function is one that we never need to call manually; it is called from inside our grow_tree function. It simply checks that it has not reached the maximum tree depth and that there are enough elements left to split. If so, it calls grow_tree on its contents. Otherwise, it returns its contents unchanged:
# Decide whether a node should be split further.
#
# df:        data frame held by the node.
# depth:     depth of the split that produced df.
# max_depth: maximum number of nested splits.
# Returns df unchanged (a leaf) when it has at most one row or the maximum
# depth has been reached; otherwise the subtree grown from df by calling
# grow_tree() one level deeper.
conditional_split <- function(df, depth, max_depth) {
  # `||` is the scalar, short-circuiting operator meant for if(); `>=` stops
  # the recursion even if depth ever overshoots max_depth.
  if (nrow(df) <= 1 || depth >= max_depth) {
    return(df)
  }
  grow_tree(df, depth + 1, max_depth)
}
Our main function can then safely split the given data frame and recursively call conditional_split with lapply:
# Randomly split df into a named list(left, right) and let conditional_split()
# decide, for each half, whether to keep splitting.
#
# df:        data frame to split; rows are split by position, so they are
#            assumed to be sorted by x already.
# depth:     depth of this split (1 for the root call).
# max_depth: maximum number of nested splits.
# Returns a nested named list whose leaves are data frames. A data frame with
# fewer than two rows cannot be split and is returned unchanged.
grow_tree <- function(df, depth = 1, max_depth = 3) {
  # sample.int(0, 1) would error for a single row, so treat it as a leaf.
  if (nrow(df) < 2) {
    return(df)
  }
  # The first break_at rows go left, the rest go right.
  break_at <- sample.int(nrow(df) - 1L, 1L)
  left_rows <- seq_len(break_at)
  # drop = FALSE keeps single-column data frames from collapsing to vectors.
  branched <- list(
    left = df[left_rows, , drop = FALSE],
    right = df[-left_rows, , drop = FALSE]
  )
  lapply(branched, conditional_split, depth = depth, max_depth = max_depth)
}
I think this does what you're looking for:
# Split points are drawn at random, so the tree differs between runs;
# call set.seed() beforehand for a reproducible result.
grow_tree(df = mydata, depth = 1, max_depth = 3)
#> $left
#> x y
#> 1 10 -10.5
#>
#> $right
#> $right$left
#> $right$left$left
#> x y
#> 2 20 6.5
#>
#> $right$left$right
#> x y
#> 3 25 7.5
#>
#>
#> $right$right
#> x y
#> 4 35 -7.5
And you can change the maximum tree depth as easily as:
# Same data with a shallower maximum depth (splits are random, so the
# resulting tree varies between runs).
grow_tree(df = mydata, depth = 1, max_depth = 2)
#> $left
#> $left$left
#> x y
#> 1 10 -10.5
#>
#> $left$right
#> x y
#> 2 20 6.5
#> 3 25 7.5
#>
#>
#> $right
#> x y
#> 4 35 -7.5
Answer 2:
Maybe you can try the code below, where another custom function, rndsplit, is defined within grow_tree:
# Build an empty tree: a list with one element per level, where level k is a
# list of 2^(k - 1) NULL slots (one slot per possible node at that level).
create_empty_tree <- function(max_height) {
  # lapply() and vector("list", n) never simplify their result, so this is a
  # list of lists for every max_height (sapply() collapsed the outer level
  # when max_height was 1), and seq_len() yields an empty tree for 0.
  lapply(seq_len(max_height), function(k) vector("list", 2^(k - 1)))
}
# Grow a random binary tree level by level (breadth-first).
#
# node_parent: data frame for the root; rows are split by position, so they
#              are assumed to be sorted by x already.
# max_height:  number of levels to build (the root is level 1); defaults to
#              nrow(node_parent). Must be at least 1.
# Returns a list of levels: level k holds 2^(k - 1) slots, and the children of
# slot p are slots 2p - 1 and 2p of the next level. Slots below a node that
# could not be split are NULL.
grow_tree <- function(node_parent, max_height = nrow(node_parent)) {
  stopifnot(max_height >= 1)

  # Randomly split a node into list(left, right). Leaves (NULL or a single
  # row) yield list(NULL, NULL) so every level keeps its positional layout.
  rndsplit <- function(x) {
    if (is.null(x) || nrow(x) <= 1) return(list(NULL, NULL))
    ind <- sample.int(nrow(x) - 1L, 1L)
    left_rows <- seq_len(ind)
    # drop = FALSE keeps single-column data frames from collapsing to vectors.
    list(x[left_rows, , drop = FALSE], x[-left_rows, , drop = FALSE])
  }

  # Every level after the first is computed from the one above it, so only
  # the root needs initializing. seq_len(max_height)[-1] is empty when
  # max_height is 1, whereas 2:max_height would count down to 1.
  tree_struc <- vector("list", max_height)
  tree_struc[[1]] <- list(node_parent)
  for (i in seq_len(max_height)[-1]) {
    tree_struc[[i]] <- unlist(lapply(tree_struc[[i - 1]], rndsplit),
                              recursive = FALSE)
  }
  tree_struc
}
Example
> grow_tree(mydata,3)
[[1]]
[[1]][[1]]
x y
1 10 -10.5
2 20 6.5
3 25 7.5
4 35 -7.5
[[2]]
[[2]][[1]]
x y
1 10 -10.5
2 20 6.5
[[2]][[2]]
x y
3 25 7.5
4 35 -7.5
[[3]]
[[3]][[1]]
x y
1 10 -10.5
[[3]][[2]]
x y
2 20 6.5
[[3]][[3]]
x y
3 25 7.5
[[3]][[4]]
x y
4 35 -7.5
and
> grow_tree(mydata)
[[1]]
[[1]][[1]]
x y
1 10 -10.5
2 20 6.5
3 25 7.5
4 35 -7.5
[[2]]
[[2]][[1]]
x y
1 10 -10.5
[[2]][[2]]
x y
2 20 6.5
3 25 7.5
4 35 -7.5
[[3]]
[[3]][[1]]
NULL
[[3]][[2]]
NULL
[[3]][[3]]
x y
2 20 6.5
[[3]][[4]]
x y
3 25 7.5
4 35 -7.5
[[4]]
[[4]][[1]]
NULL
[[4]][[2]]
NULL
[[4]][[3]]
NULL
[[4]][[4]]
NULL
[[4]][[5]]
NULL
[[4]][[6]]
NULL
[[4]][[7]]
x y
3 25 7.5
[[4]][[8]]
x y
4 35 -7.5
Source: https://stackoverflow.com/questions/61621974/r-recursive-tree-algorithm-with-a-random-split