I'm creating a tree of custom objects in Scala and my insert method throws a stack overflow because it's not tail recursive. However, I can't quite figure out how to make it
Your function can't be tail recursive as written. The reason is that your recursive calls to insert don't end the computation; they're used as subexpressions, in this case inside new Node(...). For example, if you were just searching for the bottom element, it would be easy to make it tail recursive.
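To illustrate the contrast, here is a minimal sketch of a search that is tail recursive. The contains helper is hypothetical (it is not part of the question), and the tree type is the same simplified one used in the rest of this answer. Every recursive call is the final action of the function, so the @tailrec annotation is accepted by the compiler.

```scala
import scala.annotation.tailrec

object TailRecSearchSketch {
  sealed trait Tree
  case object EmptyTree extends Tree
  case class Node(elem: Int, left: Tree, right: Tree) extends Tree

  // Hypothetical helper: membership test on a binary search tree.
  // Nothing is done with the result of the recursive call except
  // returning it, so the call is in tail position.
  @tailrec
  def contains(tree: Tree, v: Int): Boolean =
    tree match {
      case EmptyTree => false
      case Node(w, l, r) =>
        if (v == w) true
        else if (v < w) contains(l, v)
        else contains(r, v)
    }
}
```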
What's happening: as you descend the tree, calling insert on each of the nodes, you have to remember the way back to the root, because you have to reconstruct the tree after you replace a bottom leaf with your new value.
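The question's code isn't reproduced here, so the following is only an assumed reconstruction of the usual shape of such an insert, written against a simplified Int tree. It shows where the stack is used: the result of each recursive call is passed to Node(...), so every level has to wait for the level below it.

```scala
object NaiveInsertSketch {
  sealed trait Tree
  case object EmptyTree extends Tree
  case class Node(elem: Int, left: Tree, right: Tree) extends Tree

  // Not tail recursive: the inner insert(...) is an argument of Node(...),
  // so each level keeps a stack frame alive until the recursion returns.
  // Adding @tailrec here would be rejected by the compiler.
  def insert(tree: Tree, v: Int): Tree =
    tree match {
      case EmptyTree => Node(v, EmptyTree, EmptyTree)
      case Node(w, l, r) =>
        if (v < w) Node(w, insert(l, v), r)
        else Node(w, l, insert(r, v))
    }
}
```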
A possible solution: remember the path down explicitly instead of on the call stack. Let's use a simplified data structure for the example:
sealed trait Tree
case object EmptyTree extends Tree
case class Node(elem: Int, left: Tree, right: Tree) extends Tree
Now define what a path is: it's a list of nodes, each paired with a flag recording whether we went right (true) or left (false) from that node. The root is always at the end of the list, the node nearest the leaf at the start.
type Path = List[(Node, Boolean)]
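As a small worked example of this representation, consider a two-node tree and the empty slot where the value 4 would belong. The names three, root and pathFor4 are made up for illustration.

```scala
object PathExampleSketch {
  sealed trait Tree
  case object EmptyTree extends Tree
  case class Node(elem: Int, left: Tree, right: Tree) extends Tree

  type Path = List[(Node, Boolean)]

  // A tree with 5 at the root and 3 as its left child.
  val three: Node = Node(3, EmptyTree, EmptyTree)
  val root: Node = Node(5, three, EmptyTree)

  // To reach the slot for 4: go left at 5 (false), then right at 3 (true).
  // The deepest node comes first in the list, the root last.
  val pathFor4: Path = List((three, true), (root, false))
}
```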
Now we can make a tail recursive function that computes a path given a value:
import scala.annotation.tailrec

// Find a path down the tree that leads to the leaf where `v` would belong.
private def path(tree: Tree, v: Int): Path = {
  @tailrec
  def loop(t: Tree, p: Path): Path =
    t match {
      case EmptyTree => p
      case n @ Node(w, l, r) =>
        // false = went left, true = went right
        if (v < w) loop(l, (n, false) :: p)
        else loop(r, (n, true) :: p)
    }
  loop(tree, Nil)
}
and a function that takes a path and a value and reconstructs a new tree with the value as a new node at the bottom of the path:
// Given a path, reconstruct a new tree with `v` inserted at the bottom
// of the path.
private def rebuild(path: Path, v: Int): Tree = {
  @tailrec
  def loop(p: Path, subtree: Tree): Tree =
    p match {
      case Nil => subtree
      case (Node(w, _, r), false) :: q => loop(q, Node(w, subtree, r))
      case (Node(w, l, _), true) :: q => loop(q, Node(w, l, subtree))
    }
  loop(path, Node(v, EmptyTree, EmptyTree))
}
Inserting is then easy:
def insert(tree: Tree, v: Int): Tree =
  rebuild(path(tree, v), v)
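For a quick check, the pieces can be assembled into one object and exercised on a deliberately degenerate tree. This is only a sketch: the rightDepth and demo helpers and the size 5000 are assumptions made for the example. Inserting ascending values produces a tree that leans entirely to the right, which is the shape that puts the most pressure on a stack-based insert. Note that equals, hashCode and toString generated for case classes are themselves recursive, so the sketch avoids calling them on the deep tree.

```scala
import scala.annotation.tailrec

object InsertUsageSketch {
  sealed trait Tree
  case object EmptyTree extends Tree
  case class Node(elem: Int, left: Tree, right: Tree) extends Tree

  type Path = List[(Node, Boolean)]

  def path(tree: Tree, v: Int): Path = {
    @tailrec
    def loop(t: Tree, p: Path): Path =
      t match {
        case EmptyTree => p
        case n @ Node(w, l, r) =>
          if (v < w) loop(l, (n, false) :: p)
          else loop(r, (n, true) :: p)
      }
    loop(tree, Nil)
  }

  def rebuild(path: Path, v: Int): Tree = {
    @tailrec
    def loop(p: Path, subtree: Tree): Tree =
      p match {
        case Nil => subtree
        case (Node(w, _, r), false) :: q => loop(q, Node(w, subtree, r))
        case (Node(w, l, _), true) :: q => loop(q, Node(w, l, subtree))
      }
    loop(path, Node(v, EmptyTree, EmptyTree))
  }

  def insert(tree: Tree, v: Int): Tree = rebuild(path(tree, v), v)

  // Hypothetical helper: number of nodes along the rightmost spine.
  @tailrec
  def rightDepth(t: Tree, acc: Int = 0): Int =
    t match {
      case EmptyTree => acc
      case Node(_, _, r) => rightDepth(r, acc + 1)
    }

  // Hypothetical demo: ascending inserts build a right-leaning chain.
  def demo(n: Int): Int = {
    val deep = (1 to n).foldLeft[Tree](EmptyTree)((t, v) => insert(t, v))
    rightDepth(deep)
  }
}
```

Calling InsertUsageSketch.demo(5000) builds a chain 5000 nodes deep using only loops, so the depth of the tree does not translate into call-stack depth.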
Note that this version isn't particularly efficient. You could probably make it more efficient by using a Seq, or further still by using a mutable stack to store the path. But with List the idea can be expressed nicely.
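One way the mutable variant might look is sketched below. It is an assumption about what such a version could be, not a measured improvement: the path is kept in an ArrayBuffer (used here as a stack), and both the descent and the rebuild are plain while loops.

```scala
import scala.collection.mutable.ArrayBuffer

object MutablePathSketch {
  sealed trait Tree
  case object EmptyTree extends Tree
  case class Node(elem: Int, left: Tree, right: Tree) extends Tree

  def insert(tree: Tree, v: Int): Tree = {
    // Descend, recording each visited node and the direction taken
    // (true = right, false = left), root first.
    val path = ArrayBuffer.empty[(Node, Boolean)]
    var current: Tree = tree
    var descending = true
    while (descending) {
      current match {
        case EmptyTree => descending = false
        case n @ Node(w, l, r) =>
          val goRight = v >= w
          path += ((n, goRight))
          current = if (goRight) r else l
      }
    }

    // Walk back from the deepest node to the root, rebuilding one node per step.
    var subtree: Tree = Node(v, EmptyTree, EmptyTree)
    var i = path.length - 1
    while (i >= 0) {
      val (n, wentRight) = path(i)
      subtree = if (wentRight) n.copy(right = subtree) else n.copy(left = subtree)
      i -= 1
    }
    subtree
  }
}
```

The result is still an immutable tree; only the bookkeeping for the path is mutable and local to the call.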
Disclaimer: I only compiled the code, I haven't tested it at all.
Notes: