Scala tail-recursive method has a division and remainder error

Submitted by 旧街凉风 on 2021-02-19 07:47:28


I'm currently computing the binomial coefficient of two natural numbers by writing a tail-recursive function in Scala. But something is wrong with the division: integer division by k, the way I did it, can leave a non-zero remainder and hence introduce rounding errors. Could anyone help me figure out how to fix it?

def binom(n: Int, k: Int): Int = {
  require(0 <= k && k <= n)
  def binomtail(n: Int, k: Int, ac: Int): Int = {
    if (n == k || k == 0) ac
    else binomtail(n - 1, k - 1, (n * ac) / k) // truncates whenever k does not divide n * ac
  }
  binomtail(n, k, 1)
}
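To see the rounding error concretely, here is a self-contained copy of the question's descending scheme. The name `binomDescending` is hypothetical, and starting the helper with `binomtail(n, k, 1)` is an assumption about how it is invoked. For binom(5, 2) the very first step computes `(5 * 1) / 2`, which truncates 2.5 to 2, so the final result is 8 instead of 10.

```scala
import scala.annotation.tailrec

// Hypothetical reproduction of the question's approach: walk from
// (n, k) down towards k == 0, dividing by k at every step.
def binomDescending(n: Int, k: Int): Int = {
  require(0 <= k && k <= n)
  @tailrec
  def binomtail(n: Int, k: Int, ac: Int): Int =
    if (n == k || k == 0) ac
    // n * ac is not always divisible by k, so integer division loses the fraction
    else binomtail(n - 1, k - 1, (n * ac) / k)
  binomtail(n, k, 1) // assumed initial call
}
```

For example, `binomDescending(5, 2)` evaluates to 8, although the binomial coefficient is 10.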


In general, the following holds:

binom(n, k) = if (k == 0 || k == n) 1 else binom(n - 1, k - 1) * n / k
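The recurrence can be transcribed directly, as long as the multiplication happens before the division: binom(n - 1, k - 1) * n equals k * binom(n, k), so dividing that product by k is exact. This is a minimal sketch, not tail-recursive (it nests up to k calls); the name `binomRec` and the use of `Long` to postpone overflow are assumptions.

```scala
// Direct transcription of binom(n, k) = binom(n - 1, k - 1) * n / k.
// Scala evaluates a * b / c as (a * b) / c, so the product is formed first
// and the division by k leaves no remainder.
def binomRec(n: Int, k: Int): Long = {
  require(0 <= k && k <= n)
  if (k == 0 || k == n) 1L
  else binomRec(n - 1, k - 1) * n / k
}
```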

If you want to compute it in linear time, you have to make sure that every intermediate result is an integer. Now,

binom(n - k + 1, 1)

is certainly an integer (it's just n - k + 1). Starting from this number and incrementing both arguments by one at each step, you arrive at binom(n, k) through the following intermediate results:

binom(n - k + 1, 1)
binom(n - k + 2, 2)
...
binom(n - 2, k - 2)
binom(n - 1, k - 1)
binom(n, k)
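The chain of intermediate results can be made explicit with a small sketch (the name `binomChain` and the use of `Long` are illustrative assumptions, and it requires k >= 1). It starts at binom(n - k + 1, 1) = n - k + 1, steps both arguments up by one, and checks at every step that the division is exact before performing it.

```scala
// Returns every intermediate (m, j, binom(m, j)) on the way from
// binom(n - k + 1, 1) up to binom(n, k).
def binomChain(n: Int, k: Int): List[(Int, Int, Long)] = {
  require(1 <= k && k <= n)
  val start = n - k + 1
  (2 to k).scanLeft((start, 1, start.toLong)) { case ((m, _, ac), j) =>
    // (m + 1) * binom(m, j - 1) == j * binom(m + 1, j), so j divides it exactly
    val product = (m + 1) * ac
    assert(product % j == 0)
    (m + 1, j, product / j)
  }.toList
}
```

For example, `binomChain(12, 6)` passes through 7, 28, 84, 210, 462 and ends at (12, 6, 924).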

This means you have to "accumulate" in the right order, from 1 up to k rather than from k down to 1. Then every intermediate result is guaranteed to be an actual binomial coefficient, and therefore an integer rather than a fraction. Here is what it looks like as a tail-recursive function:

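import scala.annotation.tailrec

def binom(n: Int, k: Int): Int = {
  require(0 <= k && k <= n)
  // Walk from binom(n - k + 1, 1) up to binom(n, k); every intermediate
  // ac is a binomial coefficient, so the division by kIter is exact.
  @tailrec
  def binomtail(nIter: Int, kIter: Int, ac: Int): Int = {
    if (kIter > k) ac
    else binomtail(nIter + 1, kIter + 1, (nIter * ac) / kIter)
  }
  if (k == 0 || k == n) 1
  else binomtail(n - k + 1, 1, 1)
}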

A little visual test:

val n = 12
for (i <- 0 to n) {
  print(" " * ((n - i) * 2)) // indent so the rows form a triangle
  for (j <- 0 to i) {
    printf(" %3d", binom(i, j))
  }
  println()
}

                           1
                         1   1
                       1   2   1
                     1   3   3   1
                   1   4   6   4   1
                 1   5  10  10   5   1
               1   6  15  20  15   6   1
             1   7  21  35  35  21   7   1
           1   8  28  56  70  56  28   8   1
         1   9  36  84 126 126  84  36   9   1
       1  10  45 120 210 252 210 120  45  10   1
     1  11  55 165 330 462 462 330 165  55  11   1
   1  12  66 220 495 792 924 792 495 220  66  12   1

Looks OK; compare it with Pascal's triangle if you want.
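One way to do that comparison programmatically is to build the rows of Pascal's triangle using only addition and check them against the multiplicative version. This is a sketch: the tail-recursive `binom` is repeated from the answer so the block compiles on its own, and `pascalRows` and `agrees` are hypothetical helper names.

```scala
import scala.annotation.tailrec

// Tail-recursive binom from the answer (Int arithmetic).
def binom(n: Int, k: Int): Int = {
  require(0 <= k && k <= n)
  @tailrec
  def binomtail(nIter: Int, kIter: Int, ac: Int): Int =
    if (kIter > k) ac
    else binomtail(nIter + 1, kIter + 1, (nIter * ac) / kIter)
  if (k == 0 || k == n) 1
  else binomtail(n - k + 1, 1, 1)
}

// Rows 0..n of Pascal's triangle using only addition:
// each inner entry is the sum of the two entries above it.
def pascalRows(n: Int): List[List[Int]] =
  List.iterate(List(1), n + 1) { row =>
    (0 :: row).zip(row :+ 0).map { case (a, b) => a + b }
  }

// True if both methods agree on every entry of rows 0..n.
def agrees(n: Int): Boolean =
  pascalRows(n).zipWithIndex.forall { case (row, i) =>
    row.zipWithIndex.forall { case (v, j) => v == binom(i, j) }
  }
```

For example, `agrees(12)` evaluates to true, matching the visual test.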


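Andrey Tyukin's excellent example will fail for larger arguments, because the intermediate product nIter * ac overflows Int (for example binom(100000, 2), whose true value 4999950000 exceeds Int.MaxValue), but it can easily be adapted to use BigInt.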

import scala.annotation.tailrec

def binom(n: Int, k: Int): BigInt = {
  require(0 <= k && k <= n)
  @tailrec
  def binomtail(nIter: Int, kIter: Int, ac: BigInt): BigInt = {
    if (kIter > k) ac
    else binomtail(nIter + 1, kIter + 1, (ac * nIter) / kIter) // BigInt arithmetic, no overflow
  }
  if (k == 0 || k == n) BigInt(1)
  else binomtail(n - k + 1, 1, BigInt(1))
}
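A quick way to check the BigInt version is to compare it against an independent, slower computation of n! / (k! (n - k)!). In this sketch the BigInt `binom` is repeated so the block is self-contained, with the product written as `ac * nIter` so the multiplication is BigInt times Int; `factorial` and `binomFactorial` are hypothetical helpers used only for comparison.

```scala
import scala.annotation.tailrec

// BigInt version of the tail-recursive binom: no overflow, at the cost of
// arbitrary-precision arithmetic on every step.
def binom(n: Int, k: Int): BigInt = {
  require(0 <= k && k <= n)
  @tailrec
  def binomtail(nIter: Int, kIter: Int, ac: BigInt): BigInt =
    if (kIter > k) ac
    else binomtail(nIter + 1, kIter + 1, ac * nIter / kIter)
  if (k == 0 || k == n) BigInt(1)
  else binomtail(n - k + 1, 1, BigInt(1))
}

// Independent reference: n! / (k! * (n - k)!) with BigInt factorials.
def factorial(m: Int): BigInt = (1 to m).foldLeft(BigInt(1))(_ * _)

def binomFactorial(n: Int, k: Int): BigInt =
  factorial(n) / (factorial(k) * factorial(n - k))
```

For example, `binom(100000, 2)` returns 4999950000, a value the Int version cannot represent.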

