Question
f(0) = p
f(1) = q
f(2) = r
For n > 2:
f(n) = a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)
where g(n) = n*n*(n+1)
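To make the recurrence concrete, here are the first two steps unrolled. They use g(3) = 3*3*4 = 36 and g(4) = 4*4*5 = 80:

```latex
\begin{aligned}
f(3) &= a\,f(2) + b\,f(1) + c\,f(0) + g(3) = a r + b q + c p + 36 \\
f(4) &= a\,f(3) + b\,f(2) + c\,f(1) + g(4) = a\,f(3) + b r + c q + 80
\end{aligned}
```

Each term depends only on the three terms before it.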
p, q, r, a, b and c are given. The question is: how do I find the nth term of this sequence?
Please help me find a better solution.
I have tried solving this with recursion, but that approach consumes a lot of memory.
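The question does not include its code, so the following is only a guess at the plain recursive approach it describes. The parameter values are hypothetical (all set to 1) and are there only to make the sketch runnable:

```python
# Hypothetical parameter values, chosen only so the sketch runs.
p, q, r = 1, 1, 1
a, b, c = 1, 1, 1


def g(n):
    return n * n * (n + 1)


def f(n):
    # Direct translation of the recurrence: every call with n > 2
    # triggers three further recursive calls.
    if n == 0:
        return p
    if n == 1:
        return q
    if n == 2:
        return r
    return a * f(n - 1) + b * f(n - 2) + c * f(n - 3) + g(n)


print(f(5))  # 311
```

With these values f(3) = 3 + 36 = 39, f(4) = 41 + 80 = 121 and f(5) = 161 + 150 = 311.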
Answer 1:
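A better way than plain recursion is to compute the terms iteratively, from f(3) up to f(n), remembering only the last three values of f. A solution could look like this (written as Python; p, q, r, a, b, c and g are assumed to be defined as in the question):
def f(n):
    if n == 0:
        return p
    elif n == 1:
        return q
    elif n == 2:
        return r
    # f_n3, f_n2, f_n1 hold f(i-3), f(i-2), f(i-1)
    f_n3 = p
    f_n2 = q
    f_n1 = r
    for i in range(3, n + 1):
        f_new = a * f_n1 + b * f_n2 + c * f_n3 + g(i)
        # shift the window one step forward, oldest value first
        f_n3 = f_n2
        f_n2 = f_n1
        f_n1 = f_new
    return f_n1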
This way you don't need to call the function recursively and keep every intermediate result on the call stack; you only keep four variables in memory.
The loop runs in O(n) time with a constant number of variables, so it is much faster and uses much less memory than the plain recursion.
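To try this idea in isolation, here is a minimal self-contained sketch that passes the parameters explicitly instead of relying on globals. The names `nth_term` and `nth_term_recursive` and the parameter values are hypothetical. The sketch cross-checks the iterative loop against plain recursion for small n:

```python
def g(n):
    return n * n * (n + 1)


def nth_term(n, p, q, r, a, b, c):
    # Iterative computation keeping only f(i-3), f(i-2), f(i-1).
    if n == 0:
        return p
    if n == 1:
        return q
    if n == 2:
        return r
    f_n3, f_n2, f_n1 = p, q, r
    for i in range(3, n + 1):
        f_new = a * f_n1 + b * f_n2 + c * f_n3 + g(i)
        f_n3, f_n2, f_n1 = f_n2, f_n1, f_new
    return f_n1


def nth_term_recursive(n, p, q, r, a, b, c):
    # Plain recursion, used only as a reference for small n.
    if n <= 2:
        return (p, q, r)[n]
    return (a * nth_term_recursive(n - 1, p, q, r, a, b, c)
            + b * nth_term_recursive(n - 2, p, q, r, a, b, c)
            + c * nth_term_recursive(n - 3, p, q, r, a, b, c)
            + g(n))


params = (2, -1, 3, 1, 2, -1)  # hypothetical p, q, r, a, b, c
for n in range(15):
    assert nth_term(n, *params) == nth_term_recursive(n, *params)

print(nth_term(5, 1, 1, 1, 1, 1, 1))  # 311
```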
Answer 2:
The problem is that each call to f with n > 2 results in three extra calls to f. For example, if we call f(5), we get the following calls:
- f(5)
  - f(4)
    - f(3)
      - f(2)
      - f(1)
      - f(0)
      - g(3)
    - f(2)
    - f(1)
    - g(4)
  - f(3)
    - f(2)
    - f(1)
    - f(0)
    - g(3)
  - f(2)
  - g(5)
We thus make one call to f(5), one call to f(4), two calls to f(3), four calls to f(2), three calls to f(1), and two calls to f(0).
Since we call, for example, f(3) several times, each of those calls costs resources, especially since f(3) itself makes further calls.
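The call counts above can be checked, and their growth observed, by instrumenting the plain recursion with a counter. This is a minimal sketch with hypothetical parameter values; the number of calls does not depend on them. The total number of calls T(n) satisfies T(n) = 1 + T(n-1) + T(n-2) + T(n-3), so it grows exponentially with n:

```python
from collections import Counter

# Hypothetical parameter values; the call counts do not depend on them.
p, q, r = 1, 1, 1
a, b, c = 1, 1, 1
calls = Counter()  # number of times f is entered, per argument


def g(n):
    return n * n * (n + 1)


def f(n):
    calls[n] += 1
    if n <= 2:
        return (p, q, r)[n]
    return a * f(n - 1) + b * f(n - 2) + c * f(n - 3) + g(n)


f(5)
print(sorted(calls.items(), reverse=True))
# [(5, 1), (4, 1), (3, 2), (2, 4), (1, 3), (0, 2)]

# Total number of calls for larger n.
for n in (10, 15, 20):
    calls.clear()
    f(n)
    print(n, sum(calls.values()))
# 10 289
# 15 6094
# 20 128287
```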
We can let Python store the result of a function call and return that stored result when the function is called again with the same argument, for example with functools.lru_cache [Python-doc]. This technique is called memoization:
from functools import lru_cache

# p, q, r, a, b and c are assumed to be defined as in the question.

def g(n):
    return n * n * (n+1)

@lru_cache(maxsize=32)
def f(n):
    if n <= 2:
        return (p, q, r)[n]
    else:
        return a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)
This results in the following call graph (calls answered from the cache are omitted):
- f(5)
  - f(4)
    - f(3)
      - f(2)
      - f(1)
      - f(0)
      - g(3)
    - g(4)
  - g(5)
So now we calculate f(3) only once: lru_cache stores the result, and if we call f(3) a second time, the body of f is not evaluated again; the cache returns the precomputed value.
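The cache hits can be observed directly, because functions wrapped by lru_cache expose cache_info(). Here is a sketch with hypothetical parameter values. With a fresh cache, f(5) evaluates each of f(0) to f(5) once (6 misses). The four repeated calls, f(2) and f(1) inside f(4) and then f(3) and f(2) inside f(5), are answered from the cache (4 hits):

```python
from functools import lru_cache

# Hypothetical parameter values, only so the sketch runs.
p, q, r = 1, 1, 1
a, b, c = 1, 1, 1


def g(n):
    return n * n * (n + 1)


@lru_cache(maxsize=32)
def f(n):
    if n <= 2:
        return (p, q, r)[n]
    return a * f(n - 1) + b * f(n - 2) + c * f(n - 3) + g(n)


print(f(5))            # 311
print(f.cache_info())  # CacheInfo(hits=4, misses=6, maxsize=32, currsize=6)
```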
This can be optimized further, however. Since each step only needs f(n-1), f(n-2) and f(n-3), we only have to store the last three values: each iteration computes the next value from them and then shifts the variables, like:
def f(n):
    if n <= 2:
        return (p, q, r)[n]
    # f3, f2, f1 hold f(i-3), f(i-2), f(i-1)
    f3, f2, f1 = p, q, r
    for i in range(3, n+1):
        f3, f2, f1 = f2, f1, a * f1 + b * f2 + c * f3 + g(i)
    return f1
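As a sanity check of this last version, here is a self-contained sketch that compares it with the memoized recursion. The parameter values are hypothetical, and `f_memo` and `f_iter` are illustrative names. It also shows a practical advantage of the loop. The recursive version nests one stack frame per level, so on CPython a cold call far beyond sys.getrecursionlimit() raises RecursionError, while the loop has no depth limit:

```python
import sys
from functools import lru_cache

# Hypothetical parameter values, only so the sketch runs.
p, q, r = 1, 1, 1
a, b, c = 1, 1, 1


def g(n):
    return n * n * (n + 1)


@lru_cache(maxsize=None)  # unbounded cache for the recursive reference
def f_memo(n):
    if n <= 2:
        return (p, q, r)[n]
    return a * f_memo(n - 1) + b * f_memo(n - 2) + c * f_memo(n - 3) + g(n)


def f_iter(n):
    # Sliding window over the last three terms, as in the answer.
    if n <= 2:
        return (p, q, r)[n]
    f3, f2, f1 = p, q, r
    for i in range(3, n + 1):
        f3, f2, f1 = f2, f1, a * f1 + b * f2 + c * f3 + g(i)
    return f1


# Both versions agree on small inputs.
assert all(f_memo(n) == f_iter(n) for n in range(50))

# A cold recursive call nests one frame per level, so going far past
# the recursion limit raises RecursionError; the loop is unaffected.
big_n = 10 * sys.getrecursionlimit()
f_memo.cache_clear()
try:
    f_memo(big_n)
except RecursionError:
    print("f_memo: RecursionError")
print("f_iter:", f_iter(big_n).bit_length(), "bits")
```

bit_length() is used instead of str() so that printing does not depend on the integer-to-string digit limit of recent Python versions.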
Source: https://stackoverflow.com/questions/56612225/find-nth-term-of-provided-sequence