Scala recursion vs loop: performance and runtime considerations

时光取名叫无心 2021-02-06 01:17

I've written a naïve test bed to measure the performance of three kinds of factorial implementation: loop-based, non-tail-recursive and tail-recursive.

Su

3 Answers
  • 2021-02-06 02:07

    I know everyone has already answered the question, but I thought I might add one optimization: if you convert the pattern matching to simple if statements, it can speed up the tail recursion.

    object Factorial {
      type Out = BigInt
    
      def calculateByRecursion(n: Int): Out = {
        require(n>0, "n must be positive")
    
        // Not tail-recursive: the multiplication runs after the recursive call returns
        n match {
          case 1 => 1
          case _ => n * calculateByRecursion(n - 1)
        }
      }
    
      def calculateByForLoop(n: Int): Out = {
        require(n>0, "n must be positive")
    
        var accumulator: Out = 1
        for (i <- 1 to n)
          accumulator = i * accumulator
        accumulator
      }
    
      def calculateByWhileLoop(n: Int): Out = {
        require(n>0, "n must be positive")
    
        var acc: Out = 1
        var i = 1
        while (i <= n) {
          acc = i * acc
          i += 1
        }
        acc
      }
    
      def calculateByTailRecursion(n: Int): Out = {
        require(n>0, "n must be positive")
    
        @annotation.tailrec
        def fac(n: Int, acc: Out): Out = if (n==1) acc else fac(n-1, n*acc)
    
        fac(n, 1)
      }
    
      def calculateByTailRecursionUpward(n: Int): Out = {
        require(n>0, "n must be positive")
    
        @annotation.tailrec
        def fac(i: Int, acc: Out): Out = if (i == n) n*acc else fac(i+1, i*acc)
    
        fac(1, 1)
      }
    
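      // Catches Throwable (not just Exception) so that a StackOverflowError from the
      // non-tail-recursive version is reported instead of aborting the whole comparison
      def attempt(f: () => Unit): Boolean = {
        try {
          f()
          true
        } catch {
          case _: Throwable =>
            println(" <<<<< Failed...")
            false
        }
      }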
    
      def comparePerformance(n: Int): Unit = {
        // Prints the elapsed time and, optionally, an abbreviated view of the huge result
        def showOutput[A](msg: String, data: (Long, A), showResult: Boolean = true): Unit =
          if (showResult) {
            val res = data._2.toString
            val pref = res.substring(0, 5)
            val midd = res.substring((res.length - 5) / 2, (res.length - 5) / 2 + 10)
            val suff = res.substring(res.length - 5)
            printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff", data._1)
          } else {
            printf("%s in %d ms\n", msg, data._1)
          }
        def measure[A](f:()=>A): (Long, A) = {
          val start = System.currentTimeMillis
          val o = f()
          (System.currentTimeMillis - start, o)
        }
        attempt(() => showOutput ("By for loop", measure(()=>calculateByForLoop(n))))
        attempt(() => showOutput ("By while loop", measure(()=>calculateByWhileLoop(n))))
        attempt(() => showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n))))
        attempt(() => showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n))))
        attempt(() => showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n))))
      }
    }
    

    My results:

    scala> Factorial.comparePerformance(20000)
    By for loop returned 18192...5708616582...00000 in 179 ms
    By while loop returned 18192...5708616582...00000 in 159 ms
    By non-tail recursion <<<<< Failed...
    By tail recursion returned 18192...5708616582...00000 in 169 ms
    By tail recursion upward returned 18192...5708616582...00000 in 174 ms
    
    By for loop returned 18192...5708616582...00000 in 212 ms
    By while loop returned 18192...5708616582...00000 in 156 ms
    By non-tail recursion returned 18192...5708616582...00000 in 155 ms
    By tail recursion returned 18192...5708616582...00000 in 166 ms
    By tail recursion upward returned 18192...5708616582...00000 in 137 ms
    
    scala> Factorial.comparePerformance(200000)
    By for loop returned 14202...0169293868...00000 in 17467 ms
    By while loop returned 14202...0169293868...00000 in 17303 ms
    By non-tail recursion <<<<< Failed...
    By tail recursion returned 14202...0169293868...00000 in 18477 ms
    By tail recursion upward returned 14202...0169293868...00000 in 17188 ms
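    To make the pattern-matching point concrete, here are the two tail-recursive forms side by side. This is a sketch, not the answerer's code: the object and method names are hypothetical. With @tailrec both forms are compiled to a loop and compute the same value; whether the if/else form is measurably faster is the answer's observation and is not measured here.

```scala
import scala.annotation.tailrec

// Hypothetical names, used only to compare the two styles
object MatchVersusIf {
  // Tail-recursive factorial written with a pattern match on n
  @tailrec
  def facMatch(n: Int, acc: BigInt): BigInt = n match {
    case 1 => acc
    case _ => facMatch(n - 1, n * acc)
  }

  // The same function written with a plain if/else, as the answer suggests
  @tailrec
  def facIf(n: Int, acc: BigInt): BigInt =
    if (n == 1) acc else facIf(n - 1, n * acc)

  def main(args: Array[String]): Unit = {
    // Both forms compute the same value; only the generated bytecode may differ
    println(facMatch(20, 1) == facIf(20, 1)) // prints true
  }
}
```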
    
  • 2021-02-06 02:20

    For loops are not really loops; they are for-comprehensions over a Range, which the compiler rewrites into a foreach call with a closure. If you actually want a loop, use while. (Actually, I think the BigInt multiplication here is heavyweight enough that it shouldn't matter, but you will notice the difference if you're multiplying Ints.)
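    A minimal sketch of what that rewrite means, using a Long sum instead of a BigInt factorial so the loop body stays cheap (the object and method names are hypothetical). The for version and the explicit foreach version are equivalent; the while version involves no Range and no closure.

```scala
// Hypothetical names; sums 1..n three ways
object ForVersusWhile {
  // A for loop over a Range...
  def sumFor(n: Int): Long = {
    var total = 0L
    for (i <- 1 to n) total += i
    total
  }

  // ...is rewritten by the compiler into a foreach call with a closure,
  // i.e. it is equivalent to this
  def sumForeach(n: Int): Long = {
    var total = 0L
    (1 to n).foreach(i => total += i)
    total
  }

  // A while loop: plain counter and comparison, no Range object or closure
  def sumWhile(n: Int): Long = {
    var total = 0L
    var i = 1
    while (i <= n) {
      total += i
      i += 1
    }
    total
  }

  def main(args: Array[String]): Unit = {
    println(sumFor(100))     // prints 5050
    println(sumForeach(100)) // prints 5050
    println(sumWhile(100))   // prints 5050
  }
}
```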

    Also, you have confused yourself by using BigInt: the bigger your BigInt is, the slower your multiplication. Your non-tail loop counts up while your tail-recursive loop counts down, which means the accumulator in the latter grows large sooner, so it has more big numbers to multiply.
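    One way to see the counting-order effect without timing anything is to add up the accumulator's bit length before each multiplication, as a rough proxy for BigInt work. This is a sketch with hypothetical names; the real cost also depends on the multiplication algorithm BigInteger chooses. Both orders give the same product, but counting down keeps a larger accumulator at every step, because the product of the k largest factors is never smaller than the product of the k smallest.

```scala
// Hypothetical names; bit length is only a proxy for multiplication cost
object AccumulatorGrowth {
  // Multiplies 1, 2, ..., n and returns (product, total accumulator bit length seen)
  def countingUp(n: Int): (BigInt, Long) = {
    var acc = BigInt(1)
    var work = 0L
    var i = 1
    while (i <= n) {
      work += acc.bitLength
      acc = i * acc
      i += 1
    }
    (acc, work)
  }

  // Same product, multiplying n, n-1, ..., 1 so the accumulator grows large early
  def countingDown(n: Int): (BigInt, Long) = {
    var acc = BigInt(1)
    var work = 0L
    var i = n
    while (i >= 1) {
      work += acc.bitLength
      acc = i * acc
      i -= 1
    }
    (acc, work)
  }

  def main(args: Array[String]): Unit = {
    val (upResult, upWork) = countingUp(1000)
    val (downResult, downWork) = countingDown(1000)
    println(upResult == downResult) // prints true
    println(s"bit-length total counting up: $upWork, counting down: $downWork")
  }
}
```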

    If you fix these two issues, you will find that sanity is restored: loops and tail recursion run at the same speed, with both regular recursion and for slower. (Regular recursion may not be slower if the JVM's optimizations make it equivalent.)

    (Also, the stack overflow going away is probably because the JVM starts inlining, and may either make the call tail-recursive itself or unroll the recursion far enough that you no longer overflow.)

    Finally, you're getting poor results with for and while because you're multiplying with the small number on the right rather than the left. It turns out that Java's BigInteger (which Scala's BigInt wraps) multiplies faster with the smaller number on the left.
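    If you want to check the operand-order claim on your own JVM, a crude sketch like the following compares small.multiply(big) with big.multiply(small) using java.math.BigInteger directly. The names are hypothetical, the single-shot timing ignores JIT warm-up and GC, and whether one order is faster depends on the JDK's BigInteger implementation; the sketch only guarantees that both orders compute the same product.

```scala
import java.math.BigInteger

// Hypothetical names; a rough experiment, not a rigorous benchmark
object OperandOrder {
  // Small factor on the left: BigInteger.valueOf(i).multiply(acc)
  def smallOnLeft(n: Int): BigInteger = {
    var acc = BigInteger.ONE
    var i = 1
    while (i <= n) {
      acc = BigInteger.valueOf(i.toLong).multiply(acc)
      i += 1
    }
    acc
  }

  // Small factor on the right: acc.multiply(BigInteger.valueOf(i))
  def smallOnRight(n: Int): BigInteger = {
    var acc = BigInteger.ONE
    var i = 1
    while (i <= n) {
      acc = acc.multiply(BigInteger.valueOf(i.toLong))
      i += 1
    }
    acc
  }

  // Single-shot wall-clock timer; JIT warm-up and GC can easily swamp any difference
  def timeMs[A](f: => A): (Long, A) = {
    val start = System.nanoTime()
    val result = f
    ((System.nanoTime() - start) / 1000000L, result)
  }

  def main(args: Array[String]): Unit = {
    val (leftMs, left) = timeMs(smallOnLeft(20000))
    val (rightMs, right) = timeMs(smallOnRight(20000))
    println(left == right) // prints true
    println(s"small on left: $leftMs ms, small on right: $rightMs ms")
  }
}
```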

  • 2021-02-06 02:21

    Scala static (object) methods for factorial(n), written with Scala 2.12.x and Java 8:

    object Factorial {
    
      /*
       * Not tail-recursive: every call adds a stack frame, so large n throws a StackOverflowError
       */
      def recursive(n:BigInt): BigInt = {
        if(n < 0) {
          throw new ArithmeticException
        } else if(n <= 1) {
          1
        } else {
          n * recursive(n - 1)
        }
      }
    
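      /*
       * Tail-recursive: @tailrec makes scalac check that the self-call is in tail position,
       * and the call is compiled into a loop, so large n does not overflow the stack
       */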
      @scala.annotation.tailrec
      def recursiveTail(n:BigInt, acc:BigInt = 1): BigInt = {
        if(n < 0) {
          throw new ArithmeticException
        } else if(n <= 1) {
          acc
        } else {
          recursiveTail(n - 1, n * acc)
        }
      }
    
      /*
       * A while loop
       */
      def loop(n:BigInt): BigInt = {
        if(n < 0) {
          throw new ArithmeticException
        } else if(n <= 1) {
          1
        } else {
          // Explicit type: `var acc = 1` would be inferred as Int and overflow once n > 12
          var acc: BigInt = 1
          var idx = 1
          while(idx <= n) {
            acc = idx * acc
            idx += 1
          }
          acc
        }
      }
    
    }
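    To see the practical effect of the tail-recursion comment, here is a sketch that uses a Long sum instead of a BigInt factorial, so a very deep recursion finishes quickly (object and method names are hypothetical). The tail-recursive version runs in constant stack space; the non-tail version keeps one frame per level and, with a default thread stack size, is expected to overflow at this depth, which is why the call is wrapped in a catch.

```scala
import scala.annotation.tailrec

// Hypothetical names; sums 1..n recursively
object StackDepth {
  // Not tail-recursive: the addition happens after the recursive call returns,
  // so every level keeps its own stack frame
  def sumRecursive(n: Long): Long =
    if (n <= 0) 0L else n + sumRecursive(n - 1)

  // Tail-recursive: scalac compiles the self-call into a loop, so stack depth stays constant
  @tailrec
  def sumTail(n: Long, acc: Long = 0L): Long =
    if (n <= 0) acc else sumTail(n - 1, acc + n)

  def main(args: Array[String]): Unit = {
    val n = 10000000L
    println(sumTail(n)) // prints 50000005000000
    try println(sumRecursive(n))
    catch {
      // StackOverflowError is an Error, not an Exception, so it is matched explicitly
      case _: StackOverflowError => println("sumRecursive overflowed the stack")
    }
  }
}
```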
    

    Specs:

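    import org.scalatest.funspec.AnyFunSpec
    
    // The original extended the author's own SpecHelper base class (not shown); ScalaTest's
    // AnyFunSpec (ScalaTest 3.1+) provides the describe/it/assert/intercept DSL used here
    class FactorialSpecs extends AnyFunSpec {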
    
      private val smallInt = 10
      private val largeInt = 10000
    
      describe("Factorial.recursive") {
        it("return 1 for 0") {
          assert(Factorial.recursive(0) == 1)
        }
        it("return 1 for 1") {
          assert(Factorial.recursive(1) == 1)
        }
        it("return 2 for 2") {
          assert(Factorial.recursive(2) == 2)
        }
        it("returns a result, for small inputs") {
          assert(Factorial.recursive(smallInt) == 3628800)
        }
        it("throws StackOverflow for large inputs") {
          intercept[java.lang.StackOverflowError] {
            Factorial.recursive(Int.MaxValue)
          }
        }
      }
    
      describe("Factorial.recursiveTail") {
        it("return 1 for 0") {
          assert(Factorial.recursiveTail(0) == 1)
        }
        it("return 1 for 1") {
          assert(Factorial.recursiveTail(1) == 1)
        }
        it("return 2 for 2") {
          assert(Factorial.recursiveTail(2) == 2)
        }
        it("returns a result, for small inputs") {
          assert(Factorial.recursiveTail(smallInt) == 3628800)
        }
        it("returns a result, for large inputs") {
          assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
        }
      }
    
      describe("Factorial.loop") {
        it("return 1 for 0") {
          assert(Factorial.loop(0) == 1)
        }
        it("return 1 for 1") {
          assert(Factorial.loop(1) == 1)
        }
        it("return 2 for 2") {
          assert(Factorial.loop(2) == 2)
        }
        it("returns a result, for small inputs") {
          assert(Factorial.loop(smallInt) == 3628800)
        }
        it("returns a result, for large inputs") {
          assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
        }
      }
    }
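    The "large inputs" specs above only check isInstanceOf[BigInt], which would still pass if the value were wrong. A value-based cross-check catches that, and also shows why the accumulator type matters: declaring var acc = 1 makes Scala infer Int, so the loop silently wraps around once n! exceeds Int.MaxValue (13! already does). This is a standalone sketch with hypothetical names, not the answer's test suite.

```scala
import scala.annotation.tailrec

// Hypothetical names; compares factorial implementations by value
object FactorialCrossCheck {
  // The type of a var is inferred from its initializer, so this accumulator is an Int
  def loopWithIntAccumulator(n: Int): Int = {
    var acc = 1
    var idx = 1
    while (idx <= n) {
      acc = idx * acc // Int multiplication: wraps around once n! exceeds Int.MaxValue
      idx += 1
    }
    acc
  }

  // Declaring the accumulator as BigInt keeps every intermediate product exact
  def loopWithBigIntAccumulator(n: Int): BigInt = {
    var acc: BigInt = 1
    var idx = 1
    while (idx <= n) {
      acc = idx * acc
      idx += 1
    }
    acc
  }

  @tailrec
  def tail(n: Int, acc: BigInt = 1): BigInt =
    if (n <= 1) acc else tail(n - 1, n * acc)

  def main(args: Array[String]): Unit = {
    println(loopWithIntAccumulator(12))    // prints 479001600 (12! still fits in an Int)
    println(loopWithBigIntAccumulator(13)) // prints 6227020800
    println(BigInt(loopWithIntAccumulator(13)) == loopWithBigIntAccumulator(13)) // prints false
    println(loopWithBigIntAccumulator(10000) == tail(10000)) // prints true
  }
}
```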
    

    Benchmarks:

    import org.scalameter.api._
    
    class BenchmarkFactorials extends Bench.OfflineReport {
    
      val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore
    
      performance of "Factorial" in {
        measure method "loop" in {
          using(gen) in {
            n => Factorial.loop(n)
          }
        }
        measure method "recursive" in {
          using(gen) in {
            n => Factorial.recursive(n)
          }
        }
        measure method "recursiveTail" in {
          using(gen) in {
            n => Factorial.recursiveTail(n)
          }
        }
      }
    
    }
    

    Benchmark results (loop is much faster):

    [info] Test group: Factorial.loop
    [info] - Factorial.loop.Test-9 measurements:
    [info]   - at N -> 1: passed
    [info]     (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
    [info]   - at N -> 101: passed
    [info]     (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
    [info]   - at N -> 201: passed
    [info]     (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
    [info]   - at N -> 301: passed
    [info]     (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
    [info]   - at N -> 401: passed
    [info]     (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
    [info]   - at N -> 501: passed
    [info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
    [info]   - at N -> 601: passed
    [info]     (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
    [info]   - at N -> 701: passed
    [info]     (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
    [info]   - at N -> 801: passed
    [info]     (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
    [info]   - at N -> 901: passed
    [info]     (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)
    
    [info] Test group: Factorial.recursive
    [info] - Factorial.recursive.Test-10 measurements:
    [info]   - at N -> 1: passed
    [info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
    [info]   - at N -> 101: passed
    [info]     (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
    [info]   - at N -> 201: passed
    [info]     (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
    [info]   - at N -> 301: passed
    [info]     (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
    [info]   - at N -> 401: passed
    [info]     (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
    [info]   - at N -> 501: passed
    [info]     (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
    [info]   - at N -> 601: passed
    [info]     (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
    [info]   - at N -> 701: passed
    [info]     (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
    [info]   - at N -> 801: passed
    [info]     (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
    [info]   - at N -> 901: passed
    [info]     (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)
    
    [info] Test group: Factorial.recursiveTail
    [info] - Factorial.recursiveTail.Test-11 measurements:
    [info]   - at N -> 1: passed
    [info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
    [info]   - at N -> 101: passed
    [info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
    [info]   - at N -> 201: passed
    [info]     (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
    [info]   - at N -> 301: passed
    [info]     (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
    [info]   - at N -> 401: passed
    [info]     (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
    [info]   - at N -> 501: passed
    [info]     (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
    [info]   - at N -> 601: passed
    [info]     (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
    [info]   - at N -> 701: passed
    [info]     (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
    [info]   - at N -> 801: passed
    [info]     (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
    [info]   - at N -> 901: passed
    [info]     (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)
    