I've written a naïve test-bed to measure the performance of three kinds of factorial implementation: loop-based, non-tail-recursive, and tail-recursive.
Su
I know the question has already been answered, but I'd like to add one optimization: converting the pattern matching to simple if-expressions can speed up the tail recursion.
/** Several factorial implementations plus a crude timing harness for comparing them.
  *
  * Every `calculateBy*` method computes `n!` as a [[BigInt]] and rejects `n <= 0`
  * with an `IllegalArgumentException` (via `require`).
  */
object Factorial {
  type Out = BigInt

  /** Plain (non-tail) recursion: `n * (n - 1)!`.
    *
    * Every level keeps its stack frame alive until the multiplication, so a large
    * enough `n` throws `StackOverflowError` (the limit depends on the thread's stack size).
    */
  def calculateByRecursion(n: Int): Out = {
    require(n > 0, "n must be positive")
    // A plain if-expression rather than a guarded `match` with explicit `return`s.
    if (n == 1) 1 else n * calculateByRecursion(n - 1)
  }

  /** Iterative version: a `for` comprehension over the range `1 to n`. */
  def calculateByForLoop(n: Int): Out = {
    require(n > 0, "n must be positive")
    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  /** Iterative version: a hand-written `while` loop with a mutable counter. */
  def calculateByWhileLoop(n: Int): Out = {
    require(n > 0, "n must be positive")
    var acc: Out = 1
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  /** Tail recursion counting down from `n` to 1.
    *
    * `@tailrec` makes the compiler reject the helper unless it can be compiled
    * into a loop, so the stack does not grow with `n`.
    */
  def calculateByTailRecursion(n: Int): Out = {
    require(n > 0, "n must be positive")
    @annotation.tailrec
    def fac(k: Int, acc: Out): Out = if (k == 1) acc else fac(k - 1, k * acc)
    fac(n, 1)
  }

  /** Tail recursion counting up from 1 to `n`, multiplying the small factors first. */
  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n > 0, "n must be positive")
    @annotation.tailrec
    def fac(i: Int, acc: Out): Out = if (i == n) n * acc else fac(i + 1, i * acc)
    fac(1, 1)
  }

  /** Runs `f`, reporting a failure on stdout instead of propagating it.
    *
    * @param f     the action to run
    * @param label text printed in front of the failure marker; empty by default,
    *              which prints ` <<<<< Failed...`
    * @return `true` if `f` completed normally, `false` if it threw a
    *         `StackOverflowError` or any non-fatal exception
    */
  def attempt(f: () => Unit, label: String = ""): Boolean = {
    import scala.util.control.NonFatal
    try {
      f()
      true
    } catch {
      // NonFatal deliberately does not match StackOverflowError (a VirtualMachineError),
      // but the non-tail recursive variant is expected to overflow, so catch it explicitly.
      // Truly fatal errors (e.g. OutOfMemoryError) are still allowed to propagate.
      case _: StackOverflowError | NonFatal(_) =>
        println(s"$label <<<<< Failed...")
        false
    }
  }

  /** Times every implementation on the same `n` and prints one line per variant,
    * e.g. `By for loop returned 18192...5708616582...00000 in 179 ms`.
    * A variant that throws is reported as `<label> <<<<< Failed...`.
    */
  def comparePerformance(n: Int): Unit = {
    // The abbreviated "first5...middle10...last5" form is 26 characters long, so only
    // longer results are shortened; slicing a shorter string would go out of bounds.
    def abbreviate(digits: String): String =
      if (digits.length <= 26) digits
      else {
        val midStart = (digits.length - 5) / 2
        s"${digits.take(5)}...${digits.slice(midStart, midStart + 10)}...${digits.takeRight(5)}"
      }

    def showOutput[A](msg: String, data: (Long, A), withValue: Boolean = true): Unit = {
      val (elapsedMs, value) = data
      if (withValue) printf("%s returned %s in %d ms\n", msg, abbreviate(value.toString), elapsedMs)
      else printf("%s in %d ms\n", msg, elapsedMs)
    }

    // Returns (elapsed milliseconds, result). nanoTime is monotonic, which makes it
    // suitable for elapsed-time measurement, unlike the wall-clock currentTimeMillis.
    def measure[A](f: () => A): (Long, A) = {
      val start = System.nanoTime
      val result = f()
      ((System.nanoTime - start) / 1000000L, result)
    }

    val variants: Seq[(String, () => Out)] = Seq(
      "By for loop" -> (() => calculateByForLoop(n)),
      "By while loop" -> (() => calculateByWhileLoop(n)),
      "By non-tail recursion" -> (() => calculateByRecursion(n)),
      "By tail recursion" -> (() => calculateByTailRecursion(n)),
      "By tail recursion upward" -> (() => calculateByTailRecursionUpward(n))
    )
    for ((label, compute) <- variants)
      attempt(() => showOutput(label, measure(compute)), label)
  }
}
My results:
scala> Factorial.comparePerformance(20000)
By for loop returned 18192...5708616582...00000 in 179 ms
By while loop returned 18192...5708616582...00000 in 159 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 18192...5708616582...00000 in 169 ms
By tail recursion upward returned 18192...5708616582...00000 in 174 ms
By for loop returned 18192...5708616582...00000 in 212 ms
By while loop returned 18192...5708616582...00000 in 156 ms
By non-tail recursion returned 18192...5708616582...00000 in 155 ms
By tail recursion returned 18192...5708616582...00000 in 166 ms
By tail recursion upward returned 18192...5708616582...00000 in 137 ms
scala> Factorial.comparePerformance(200000)
By for loop returned 14202...0169293868...00000 in 17467 ms
By while loop returned 14202...0169293868...00000 in 17303 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 14202...0169293868...00000 in 18477 ms
By tail recursion upward returned 14202...0169293868...00000 in 17188 ms