Tail recursive function to find depth of a tree in Ocaml

别那么骄傲 2020-11-29 01:50

I have a type tree defined as follows

type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree ;;

I have a function to

3 Answers
  • 2020-11-29 01:59

    There's a neat and generic solution using fold_tree and CPS (continuation-passing style):

    let fold_tree tree f acc =
      let rec loop t cont =
        match t with
        | Leaf _ -> cont acc
        | Node (x, left, right) ->
          loop left (fun lacc ->
            loop right (fun racc ->
              cont @@ f x lacc racc))
      in loop tree (fun x -> x)
    
    let depth tree = fold_tree tree (fun x dl dr -> 1 + (max dl dr)) 0
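Because the fold takes the combining function as a parameter, the same traversal can compute other quantities as well. The following is a minimal sketch, assuming fold_tree is written with `let rec` and matches on the loop's own argument; `node_count`, `node_sum` and `sample` are names introduced here for illustration only. Note that this fold maps every leaf to `acc` and ignores the value stored in it.

```ocaml
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* CPS fold: every call to [loop] and [cont] is in tail position. *)
let fold_tree tree f acc =
  let rec loop t cont =
    match t with
    | Leaf _ -> cont acc                    (* leaves contribute [acc] *)
    | Node (x, left, right) ->
      loop left (fun lacc ->
        loop right (fun racc ->
          cont (f x lacc racc)))
  in
  loop tree (fun x -> x)

(* Depth: one more than the deeper of the two subtrees. *)
let depth tree = fold_tree tree (fun _ dl dr -> 1 + max dl dr) 0

(* Illustrative extras: number of internal nodes and sum of their labels. *)
let node_count tree = fold_tree tree (fun _ nl nr -> 1 + nl + nr) 0
let node_sum tree = fold_tree tree (fun x sl sr -> x + sl + sr) 0

let sample = Node (1, Node (2, Leaf 0, Leaf 0), Leaf 0)
```

Only the function passed to fold_tree changes between the three definitions; the continuation plumbing is written once.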
    
  • 2020-11-29 02:18

    In this case (depth computation), you can accumulate over pairs (subtree depth * subtree content) to obtain the following tail-recursive function:

    let depth tree =
      let rec aux depth = function
        | [] -> depth
        | (d, Leaf _) :: t -> aux (max d depth) t
        | (d, Node (_,left,right)) :: t ->
          let accu = (d+1, left) :: (d+1, right) :: t in
          aux depth accu in
      aux 0 [(0, tree)]
    

    For more general cases, you will indeed need to use the CPS transformation described by Gabriel.
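To see the accumulator-based version at work, it can be run on a deliberately deep tree. This is only a sketch: `left_spine` is a hypothetical helper introduced here, and the size of one million nodes is an arbitrary choice. The pending subtrees live in an ordinary list on the heap, so the recursion in `aux` does not grow the call stack.

```ocaml
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* Worklist of (depth of subtree root, subtree) pairs. *)
let depth tree =
  let rec aux depth = function
    | [] -> depth
    | (d, Leaf _) :: t -> aux (max d depth) t
    | (d, Node (_, left, right)) :: t ->
      aux depth ((d + 1, left) :: (d + 1, right) :: t)
  in
  aux 0 [(0, tree)]

(* Hypothetical helper: a tree whose left spine has [n] nested nodes.
   Built bottom-up with a tail-recursive loop. *)
let left_spine n =
  let rec build acc i =
    if i = 0 then acc else build (Node (i, acc, Leaf 0)) (i - 1)
  in
  build (Leaf 0) n

let deep = left_spine 1_000_000
let deep_depth = depth deep
```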

  • 2020-11-29 02:19

    You can trivially do this by turning the function into CPS (Continuation Passing Style). The idea is that instead of calling depth left, and then computing things based on this result, you call depth left (fun dleft -> ...), where the second argument is "what to compute once the result (dleft) is available".

    let depth tree =
      let rec depth tree k = match tree with
        | Leaf x -> k 0
        | Node(_,left,right) ->
          depth left (fun dleft ->
            depth right (fun dright ->
              k (1 + (max dleft dright))))
      in depth tree (fun d -> d)
    

    This is a well-known trick that can make any function tail-recursive. Voilà, it's tail-rec.
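To illustrate that the transformation is not specific to depth, here is a sketch of the same recipe applied to a different function, a map over the tree. `map_direct` and `map_cps` are names made up for this example; the direct version is the assumed starting point, and the CPS version is obtained by naming "what happens after each recursive call" as a continuation.

```ocaml
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

(* Direct style: the two recursive calls are not tail calls. *)
let rec map_direct f = function
  | Leaf x -> Leaf (f x)
  | Node (x, left, right) ->
    let y = f x in
    let l = map_direct f left in
    let r = map_direct f right in
    Node (y, l, r)

(* CPS: the work that followed each call moves into a closure. *)
let map_cps f tree =
  let rec go t k = match t with
    | Leaf x -> k (Leaf (f x))
    | Node (x, left, right) ->
      let y = f x in
      go left (fun l ->
        go right (fun r ->
          k (Node (y, l, r))))
  in
  go tree (fun t -> t)

let sample = Node (1, Node (2, Leaf 3, Leaf 4), Leaf 5)
```

Both versions apply `f` in the same order (node, then left subtree, then right subtree), so they agree even when `f` has side effects.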

    The next well-known trick in the bag is to "defunctionalize" the CPS result. The representation of continuations (the (fun dleft -> ...) parts) as functions is neat, but you may want to see what it looks like as data. So we replace each of these closures by a concrete constructor of a datatype, that captures the free variables used in it.

    Here we have three continuation closures: (fun dleft -> depth right (fun dright -> k ...)), which only reuses the environment variables right and k, (fun dright -> ...), which reuses k and the now-available left result dleft, and (fun d -> d), the initial computation, that doesn't capture anything.

    type ('a, 'b) cont =
      | Kleft of 'a tree * ('a, 'b) cont (* right and k *)
      | Kright of 'b * ('a, 'b) cont     (* dleft and k *)
      | Kid
    

    The defunctionalized function looks like this:

    let depth tree =
      let rec depth tree k = match tree with
        | Leaf x -> eval k 0
        | Node(_,left,right) ->
          depth left (Kleft(right, k))
      and eval k d = match k with
        | Kleft(right, k) ->
          depth right (Kright(d, k))
        | Kright(dleft, k) ->
          eval k (1 + max d dleft)
        | Kid -> d
      in depth tree Kid
    ;;
    

    Instead of building a function k and applying it on the leaves (k 0), I build a value of type ('a, int) cont, which is evaluated later to compute the result. eval, when it is passed a Kleft, does what the closure (fun dleft -> ...) was doing, that is, it recursively calls depth on the right subtree. eval and depth are mutually recursive.

    Now look hard at ('a, 'b) cont, what is this datatype? It's a list!

    type ('a, 'b) next_item =
      | Kleft of 'a tree
      | Kright of 'b
    
    type ('a, 'b) cont = ('a, 'b) next_item list
    
    let depth tree =
      let rec depth tree k = match tree with
        | Leaf x -> eval k 0
        | Node(_,left,right) ->
          depth left (Kleft(right) :: k)
      and eval k d = match k with
        | Kleft(right) :: k ->
          depth right (Kright(d) :: k)
        | Kright(dleft) :: k ->
          eval k (1 + max d dleft)
        | [] -> d
      in depth tree []
    ;;
    

    And a list is a stack. What we have here is actually a reification (transformation into data) of the call stack of the previous recursive function, with two different cases corresponding to the two different kinds of non-tailrec calls.
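One way to make the "list as call stack" reading concrete is to instrument the list-based version so that it also reports how long the list ever gets. This is a sketch, not part of the original answer: the extra argument `n` (current stack length), the `peak` reference and the name `depth_with_peak` are all additions made for illustration.

```ocaml
type 'a tree = Leaf of 'a | Node of 'a * 'a tree * 'a tree

type ('a, 'b) next_item =
  | Kleft of 'a tree   (* right subtree still to visit *)
  | Kright of 'b       (* depth of the left subtree, already computed *)

(* Returns (depth of the tree, longest explicit stack seen). *)
let depth_with_peak tree =
  let peak = ref 0 in
  (* [n] is always the length of [k]. *)
  let rec depth tree k n =
    if n > !peak then peak := n;
    match tree with
    | Leaf _ -> eval k n 0
    | Node (_, left, right) -> depth left (Kleft right :: k) (n + 1)
  and eval k n d = match k with
    | Kleft right :: k -> depth right (Kright d :: k) n  (* swap top frame *)
    | Kright dleft :: k -> eval k (n - 1) (1 + max d dleft)  (* pop *)
    | [] -> d
  in
  let d = depth tree [] 0 in
  (d, !peak)

let sample = Node (1, Node (2, Leaf 0, Leaf 0), Leaf 0)
let result = depth_with_peak sample
```

On this sample both components are 2: each node on the path from the root to the current subtree contributes exactly one list element, which is what the native call stack of the direct recursive function would hold.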

    Note that the defunctionalization is only there for fun. In practice the CPS version is short, easy to derive by hand, rather easy to read, and I would recommend using it. Closures must be allocated in memory, but so are elements of ('a, 'b) cont, although those might be represented more compactly. I would stick to the CPS version unless there are very good reasons to do something more complicated.
