While researching how to do memoization in Scala, I've found some code I didn't grok. I've tried to look this particular "thing" up, but don't know what to call it; i.
This answer is a synthesis of the partial answers provided by both 0__ and Nicolas Rinaudo.
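Summary:
There are many convenient (but also highly intertwined) assumptions being made by the Scala compiler.
1. Scala treats `extends (A => B)` as synonymous with `extends Function1[A, B]` (ScalaDoc for `Function1[-T1, +R]`).
2. A concrete implementation of `Function1`'s inherited abstract method `apply(x: A): B` must be provided: `def apply(x: A): B = cache.getOrElseUpdate(x, f(x))`.
3. Scala assumes an implied `match` for the code block starting with `= Memo {`.
4. Scala passes the content between the `{}` started in item 3 as the parameter to the `Memo` case class constructor (through the companion object's generated `apply`).
5. Scala treats the content between the `{}` started in item 3 as an anonymous function from `Int` to `BigInt` and uses the "match" code block as its `apply()`. Because `Memo`'s parameter is typed `A => B`, the compiler actually produces a plain `Function1[Int, BigInt]`; it builds a `PartialFunction[Int, BigInt]`, additionally overriding `isDefinedAt()`, only where a `PartialFunction` is expected. Writing the block out as a `PartialFunction`, as in the details below, still yields an equivalent program, since `PartialFunction` extends `Function1`.

Details: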
The first code block, defining the case class `Memo`, can be written more verbosely as follows:

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends Function1[A, B] { // replaced (A => B) with what the Scala compiler translates it to
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x)) // concrete implementation of the abstract method inherited from the trait Function1
}
```
The second code block, defining the `val fibonacci`, can be written more verbosely as follows:

```scala
lazy val fibonacci: Memo[Int, BigInt] = {
  Memo.apply(
    new PartialFunction[Int, BigInt] {
      override def apply(x: Int): BigInt = {
        x match {
          case 0 => 0
          case 1 => 1
          case n => fibonacci(n-1) + fibonacci(n-2)
        }
      }
      override def isDefinedAt(x: Int): Boolean = true
    }
  )
}
```
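Had to add `lazy` to the second code block's `val` in order to deal with the self-reference in the line `case n => fibonacci(n-1) + fibonacci(n-2)`. (A plain `val` works as a member of a class or object, but inside a method body the compiler rejects it as a forward reference.)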
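A minimal sketch of the local case, assuming the definition is moved inside a method body (the method `fibAt` is hypothetical, used only for illustration): the recursive definition compiles because `fibonacci` is a `lazy val`; replacing it with a plain `val` triggers a forward-reference error.

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

// Hypothetical helper: builds a fresh memoized fibonacci inside a method body.
def fibAt(k: Int): BigInt = {
  // A plain `val` here would be rejected as a forward reference,
  // because the function body mentions `fibonacci` before its definition completes.
  lazy val fibonacci: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fibonacci(n - 1) + fibonacci(n - 2)
  }
  fibonacci(k)
}

val twentieth = fibAt(20) // 6765
```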
And finally, an example usage of `fibonacci` is:

```scala
val x: BigInt = fibonacci(20) // returns 6765 (almost instantly)
```
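To see why the call returns almost instantly, here is a minimal sketch that counts how often the wrapped function body actually runs. The counter `bodyCalls` is a hypothetical addition for illustration, not part of the original code; the body runs once per distinct argument, and a repeated call is served entirely from the cache.

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

// Hypothetical instrumentation: counts evaluations of the match body.
var bodyCalls = 0

lazy val fibonacci: Memo[Int, BigInt] = Memo { n =>
  bodyCalls += 1
  n match {
    case 0 => 0
    case 1 => 1
    case _ => fibonacci(n - 1) + fibonacci(n - 2)
  }
}

val first = fibonacci(20)           // evaluates the body once for each n in 0..20
val callsAfterFirst = bodyCalls     // 21
val second = fibonacci(20)          // answered from the cache, no new evaluations
```

Without the cache, the naive recursion would evaluate the body thousands of times for the same input.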
`A => B` is short for `Function1[A, B]`, so your `Memo` extends a function from `A` to `B`, most prominently defined through the method `apply(x: A): B`, which must be defined.
Because of the "infix" notation, you need to put parentheses around the type, i.e. `(A => B)`. You could also write

```scala
case class Memo[A, B](f: A => B) extends Function1[A, B] ...
```

or

```scala
case class Memo[A, B](f: Function1[A, B]) extends Function1[A, B] ...
```
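Because `A => B` and `Function1[A, B]` name the same type, a class that extends `(A => B)` is a function value in every respect and also inherits `Function1`'s concrete methods such as `andThen`. A minimal sketch (the names `Doubler`, `viaAlias` and `viaTrait` are illustrative only):

```scala
// `Int => Int` and `Function1[Int, Int]` are the same type, so these two
// values can be assigned to each other without any conversion.
val viaAlias: Int => Int = x => x + 1
val viaTrait: Function1[Int, Int] = viaAlias

// Parentheses are required around the function type in an `extends` clause.
class Doubler extends (Int => Int) {
  def apply(x: Int): Int = x * 2
}

val double = new Doubler
// `andThen` is inherited from Function1: first double, then add one.
val doubleThenInc: Int => Int = double.andThen(viaTrait)
```

For example, `doubleThenInc(5)` is `11`, and `List(1, 2, 3).map(double)` compiles because `double` is itself an `Int => Int`.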
To complete 0_'s answer, `fibonacci` is being instantiated through the `apply` method of `Memo`'s companion object, generated automatically by the compiler since `Memo` is a case class.
This means that the following code is generated for you:

```scala
object Memo {
  def apply[A, B](f: A => B): Memo[A, B] = new Memo(f)
}
```
Scala has special handling for the `apply` method: its name need not be typed when calling it. The two following calls are strictly equivalent:

```scala
Memo((a: Int) => a * 2)
Memo.apply((a: Int) => a * 2)
```
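As a small illustration of the generated companion `apply`, here is a sketch with a hypothetical case class `Point`: all three expressions build equal instances. An ordinary object with a hand-written `apply` (the hypothetical `Square`) can be called with the same shorthand.

```scala
// Hypothetical case class: the compiler generates a companion `Point.apply(x, y)`.
case class Point(x: Int, y: Int)

val a = Point(1, 2)        // sugar for Point.apply(1, 2)
val b = Point.apply(1, 2)  // the generated companion method, called explicitly
val c = new Point(1, 2)    // the constructor that the generated apply delegates to

// The same call syntax works for any object that defines `apply` by hand.
object Square {
  def apply(n: Int): Int = n * n
}
val nine = Square(3)       // same as Square.apply(3)
```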
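The `case` block is known as pattern matching; more precisely, `{ case ... }` is a pattern-matching anonymous function. Under the hood, it can generate a partial function, that is, a function that is defined for some of its input parameters, but not necessarily all of them. I won't go into the details of partial functions as it's beside the point (this is a memo I wrote to myself on that topic, if you're keen), but what it essentially means here is that the `case` block can be treated as an instance of `PartialFunction`.
If you follow that link, you'll see that `PartialFunction` extends `Function1`, which is the expected argument of `Memo.apply`. Strictly speaking, since `Memo.apply` expects a `Function1`, the compiler emits a plain `Function1` whose body is the match and builds a `PartialFunction` only where one is expected; writing it as a `PartialFunction` still gives an equivalent program.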
So what that bit of code actually means, once desugared (if that's a word), is:
```scala
lazy val fibonacci: Memo[Int, BigInt] = Memo.apply(new PartialFunction[Int, BigInt] {
  override def apply(v: Int): BigInt =
    if (v == 0) 0
    else if (v == 1) 1
    else fibonacci(v - 1) + fibonacci(v - 2)
  override def isDefinedAt(v: Int): Boolean = true
})
```
Note that I've vastly simplified the way the pattern matching is handled, but I thought that starting a discussion about `unapply` and `unapplySeq` would be off topic and confusing.
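To make the partial-function point concrete, the same `{ case ... }` syntax produces a `PartialFunction` (with a generated `isDefinedAt`) when one is expected, and a plain function when a function type is expected. A minimal sketch; the names are illustrative only:

```scala
// Expected type PartialFunction: the compiler also generates isDefinedAt.
val smallNames: PartialFunction[Int, String] = {
  case 0 => "zero"
  case 1 => "one"
}

// Expected type Int => String: the block becomes a plain function whose body is the match.
val describe: Int => String = {
  case 0 => "zero"
  case n => s"non-zero: $n"
}

// A PartialFunction is also a Function1, so it can go where a function is expected...
val asFunction: Int => String = smallNames

// ...and collect consults isDefinedAt to skip inputs the cases do not cover.
val named = List(0, 1, 2).collect(smallNames)
```

Here `named` is `List("zero", "one")`; calling `smallNames(2)` directly would throw a `MatchError`.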
I am the original author of this way of doing memoization. You can see some sample usages in that same file. It also works really well when you want to memoize on multiple arguments, because Scala automatically packs multiple arguments into a tuple:
```scala
/**
 * @return memoized function to calculate C(n,r)
 * see http://mathworld.wolfram.com/BinomialCoefficient.html
 */
val c: Memo[(Int, Int), BigInt] = Memo {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n-r)
  case (n, r) => c(n-1, r-1) + c(n-1, r)
}

// note how I can invoke a memoized function on multiple args too
val x = c(10, 3)
```
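The same idea applies to an existing two-argument function: `tupled` (defined on `Function2`) turns it into a function of a single tuple, which `Memo` can wrap, and a call like `power(2, 10)` is then auto-tupled into `power((2, 10))`. A minimal sketch; `slowPower` and its `evaluations` counter are hypothetical:

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

// Hypothetical expensive two-argument function, instrumented with a counter.
var evaluations = 0
val slowPower: (Int, Int) => BigInt = (base, exp) => {
  evaluations += 1
  BigInt(base).pow(exp)
}

// Function2#tupled gives ((Int, Int)) => BigInt, which Memo can wrap directly.
val power: Memo[(Int, Int), BigInt] = Memo(slowPower.tupled)

val p1 = power(2, 10)    // auto-tupled to power((2, 10)); computes the value
val p2 = power((2, 10))  // same key, answered from the cache
```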
One more word about this `extends (A => B)`: the `extends` here is not required, but it is necessary if instances of `Memo` are to be used in higher-order functions or similar situations. Without `extends (A => B)`, it's totally fine to use the `Memo` instance `fibonacci` in plain method calls:
```scala
case class Memo[A, B](f: A => B) {
  private val cache = scala.collection.mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

val fibonacci: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fibonacci(n-1) + fibonacci(n-2)
}
```
For example:

```
scala> fibonacci(30)
res1: BigInt = 832040
```
But when you want to use it in higher-order functions, you get a type mismatch error:
```
scala> Range(1, 10).map(fibonacci)
<console>:11: error: type mismatch;
 found   : Memo[Int,BigInt]
 required: Int => ?
       Range(1, 10).map(fibonacci)
                        ^
```
So the `extends` here only serves to identify the instance `fibonacci` to other code as a function (something with an `apply` method), so that it can be passed wherever an `Int => BigInt` is expected.
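If a `Memo` without the `extends` is all you have, wrapping the call in a lambda such as `plainFib(_)` also satisfies the `Int => ?` parameter; with the `extends`, the instance itself can be passed. A minimal sketch contrasting the two, where `PlainMemo`, `FunMemo`, `plainFib` and `funFib` are hypothetical names:

```scala
import scala.collection.mutable

// Hypothetical variant without `extends (A => B)`: it only has an apply method.
case class PlainMemo[A, B](f: A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

// Hypothetical variant with `extends (A => B)`: it is itself a function value.
case class FunMemo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

lazy val plainFib: PlainMemo[Int, BigInt] = PlainMemo {
  case 0 => 0
  case 1 => 1
  case n => plainFib(n - 1) + plainFib(n - 2)
}

lazy val funFib: FunMemo[Int, BigInt] = FunMemo {
  case 0 => 0
  case 1 => 1
  case n => funFib(n - 1) + funFib(n - 2)
}

// Without the extends, wrap the call in a lambda.
val viaLambda = Range(1, 10).map(plainFib(_))
// With the extends, the instance is passed directly.
val direct = Range(1, 10).map(funFib)
```

Both produce the first nine Fibonacci numbers starting at `fib(1)`; the lambda works, but every call site has to remember to write it.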