I have been trying to implement memoization in Julia for the Fibonacci function. This is what I came up with.
The original unmodified code (for control purposes)
The simplest way to do it is to use get! with a do block:
const fibmem = Dict{Int,Int}()
function fib(n)
    get!(fibmem, n) do
        n < 3 ? 1 : fib(n-1) + fib(n-2)
    end
end
Note the const specifier on fibmem. This avoids the need for a global declaration, and makes the code faster because the compiler then knows the type of fibmem and can use type inference within fib.
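To see the cache at work, here is a small self-contained sketch of the same get! pattern. The names fibcache and fib_dict are made up for this illustration so that they do not clash with the definitions above; the pattern itself is unchanged.

```julia
# Hypothetical names (fibcache, fib_dict), chosen for this sketch only.
const fibcache = Dict{Int,Int}()

function fib_dict(n)
    # get! returns the value stored under n; if there is none, it runs
    # the block, stores the result under n, and returns it.
    get!(fibcache, n) do
        n < 3 ? 1 : fib_dict(n - 1) + fib_dict(n - 2)
    end
end

println(fib_dict(80))       # 23416728348467685
println(length(fibcache))   # 80, one entry for each n in 1:80
```

Keep in mind that with 64-bit Int values the result overflows from n = 93 onward, so a Dict{Int,BigInt} would be needed for larger arguments.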
Since the arguments to the function are integers, you can use a simple array, which will be faster than a Dict (make sure you use BigInts in the cache for large arguments to avoid overflow):
function fib(n, cache=sizehint!(BigInt[0,1],n))
    n < length(cache) && return cache[n+1]
    f = fib(n-1,cache) + fib(n-2,cache)
    push!(cache,f)
    return f
end
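One thing to be aware of: the default argument builds a fresh cache on every top-level call, so results are only reused within that one call. A minimal sketch, assuming you want to keep results between calls, is to create the vector yourself and pass it in. The names fib_vec and shared are hypothetical and used only here.

```julia
# Same logic as the array version above, renamed to fib_vec (hypothetical
# name) so that the sketch is self-contained.
function fib_vec(n, cache=sizehint!(BigInt[0, 1], n))
    # cache[k+1] holds the k-th Fibonacci number, starting from fib(0) = 0.
    n < length(cache) && return cache[n+1]
    f = fib_vec(n - 1, cache) + fib_vec(n - 2, cache)
    push!(cache, f)
    return f
end

shared = BigInt[0, 1]               # caller-owned cache, reused across calls
println(fib_vec(100, shared))       # 354224848179261915075
println(length(shared))             # 101, entries for 0 through 100
println(fib_vec(93, shared) > typemax(Int))   # true: this would not fit in an Int
```

After the first call, any argument up to 100 is answered by a single array lookup.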
As pointed out in the comments, the Memoize.jl package is certainly the easiest option. This requires you to mark the method at definition time.
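For reference, a minimal sketch of the Memoize.jl route. It assumes the package is installed (it is a third-party package, not part of the standard library), and the function name fib_memo is made up for this example.

```julia
# Assumes Memoize.jl has been installed, e.g. with: import Pkg; Pkg.add("Memoize")
using Memoize

# @memoize is applied where the method is defined; later calls with the
# same argument are answered from a cache maintained by the package.
@memoize function fib_memo(n::Int)
    # big(1) makes the result a BigInt, so large n does not overflow.
    n < 3 ? big(1) : fib_memo(n - 1) + fib_memo(n - 2)
end

println(fib_memo(100))
```

The recursive calls inside the body go through the memoized method as well, so each value is computed only once.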
By far the most powerful approach, however, is to use Cassette.jl, which lets you add memoization to pre-existing functions, e.g.
fib(x) = x < 3 ? 1 : fib(x-2) + fib(x-1)
using Cassette
Cassette.@context MemoizeCtx
function Cassette.overdub(ctx::MemoizeCtx, ::typeof(fib), x)
    get(ctx.metadata, x) do
        # Qualified name, so this works whether or not recurse is exported.
        result = Cassette.recurse(ctx, fib, x)
        ctx.metadata[x] = result
        return result
    end
end
A little bit of a description of what is going on:
MemoizeCtx is the Cassette "context" which we are defining.
recurse(...) tells Cassette to call the function, but ignore the top-level overload.
Now we can run the function with memoization:
Cassette.overdub(MemoizeCtx(metadata=Dict{Int,Int}()), fib, 80)
Now what's even cooler is that we can take an existing function which calls fib, and memoize the call to fib inside that function:
function foo()
    println("calling fib")
    @show fib(80)
    println("done.")
end
Cassette.overdub(MemoizeCtx(metadata=Dict{Int,Int}()), foo)
(Cassette is still pretty hard on the compiler, so this may take a while to run the first time, but will be fast after that).