Question
I would like to write a curve-fitting script that allows me to fix parameters of a function of the form:
import numpy as np

def func(x, *p):
    # p = [a1, k1, a2, k2, ...]: one amplitude/rate pair per exponential term
    assert len(p) % 2 == 0
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval
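Written out, func is a sum of decaying exponentials. Using the 1-based names p1, p2, ... (the code indexes p from 0), each odd-numbered parameter is an amplitude and the following even-numbered one is its decay rate:

```latex
f(x;\, p_1, \dots, p_{2n}) = \sum_{k=1}^{n} p_{2k-1}\, e^{-p_{2k} x}
```

With four parameters this is p1*exp(-p2*x) + p3*exp(-p4*x), so p2 is the first term's rate and p3 is the second term's amplitude.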
For example, say p = [p1, p2, p3, p4] and I want p2 and p3 fixed at constants A and B (going from a 4-parameter fit to a 2-parameter fit). As I understand it, functools.partial can't do this, because it only binds positional arguments from the left (or keyword arguments), which is why I want to write my own wrapper. I'm having some trouble doing so, though. This is what I have so far:
def fix_params(f, t, pars, fix_pars):
    # fix_pars = ((ind1, A), (ind2, B))
    new_pars = [None]*(len(pars) + len(fix_pars))
    # put each fixed value at its index
    for ind, fix in fix_pars:
        new_pars[ind] = fix
    # fill the remaining slots, in order, with the free parameters
    for par in pars:
        for j, npar in enumerate(new_pars):
            if npar is None:
                new_pars[j] = par
                break
    assert None not in new_pars
    return f(t, *new_pars)
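The problem, I think, is that scipy.optimize.curve_fit won't work well with a function passed through this kind of wrapper: curve_fit expects a model it can call as f(xdata, *params), while this wrapper evaluates f immediately and returns the result. How should I get around this?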
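Two points from the question can be checked directly: functools.partial binds positional arguments from the left, so it can fix x and p1 but cannot skip p1 to fix p2 and p3; and the existing fix_params can still be used if it is wrapped in a function whose free parameters are named arguments. A minimal sketch follows; the values A = 0.5 and B = 0.1 and the name model are illustrative assumptions.

```python
import functools

import numpy as np


def func(x, *p):
    # Sum of exponentials: p = [a1, k1, a2, k2, ...]
    assert len(p) % 2 == 0
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


def fix_params(f, t, pars, fix_pars):
    # The question's helper: fixed values go at their indices and the
    # free values fill the remaining slots in order.
    new_pars = [None] * (len(pars) + len(fix_pars))
    for ind, fix in fix_pars:
        new_pars[ind] = fix
    for par in pars:
        for j, npar in enumerate(new_pars):
            if npar is None:
                new_pars[j] = par
                break
    assert None not in new_pars
    return f(t, *new_pars)


x = np.arange(10)
A, B = 0.5, 0.1  # illustrative fixed values for p2 and p3
fix_pars = ((1, A), (2, B))

# partial fills positional slots from the left: here it fixes x and p1,
# so p2, p3 and p4 must still be supplied at call time.
g = functools.partial(func, x, 1.0)
print(np.allclose(g(A, B, 1.2), func(x, 1.0, A, B, 1.2)))  # True


# Minimal fix: a function with named free parameters that delegates
# to the existing helper.
def model(x, p1, p4):
    return fix_params(func, x, (p1, p4), fix_pars)


print(np.allclose(model(x, 1.0, 1.2), func(x, 1.0, A, B, 1.2)))  # True
```

With SciPy installed, the fit would then look something like curve_fit(model, xdata, ydata, p0=[1.0, 1.2]), where xdata and ydata stand for your data.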
Answer 1:
It sounds like what you want is currying. In Python you can do this with inner functions (closures).
Example:
def foo(x):
    def bar(y):
        return x + y
    return bar

bar = foo(3)
print(type(bar))  # a function (of one variable, with the other fixed to 3)
print(bar(8))  # 11
bar = foo(9)
print(bar(8))  # 17
In this way we can fix x in the function x + y. You can also put this into a decorator.
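As one way to put this into a decorator, here is a sketch of a decorator factory. The name fixed_params is hypothetical, and the fixed positions are assumed to be 0-based indices into *p. Unlike a plain wrapper, the decorator fixes the values once, when the model is defined.

```python
import functools

import numpy as np


def fixed_params(fixed):
    """Hypothetical decorator factory: `fixed` maps 0-based positions
    in *p to constant values."""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(x, *free):
            total = len(free) + len(fixed)
            if any(not 0 <= i < total for i in fixed):
                raise ValueError("fixed index out of range")
            free_iter = iter(free)
            # Fixed value where one is given, otherwise the next free value.
            full = [fixed[i] if i in fixed else next(free_iter)
                    for i in range(total)]
            return f(x, *full)
        return wrapper
    return decorator


@fixed_params({1: 0.5, 2: 0.1})
def two_term_model(x, *p):
    # Sum of exponentials: p = [a1, k1, a2, k2, ...]
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


x = np.arange(10)
# Only p1 and p4 are passed; p2 = 0.5 and p3 = 0.1 come from the decorator.
expected = 1.0 * np.exp(-0.5 * x) + 0.1 * np.exp(-1.2 * x)
print(np.allclose(two_term_model(x, 1.0, 1.2), expected))  # True
```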
Here's a blog post someone wrote on doing this: https://mtomassoli.wordpress.com/2012/03/18/currying-in-python/
As for playing nicely with external libraries: foo returns a function, and in Python functions are first-class objects, so anything you pass the returned function to (scipy.optimize.curve_fit included) just sees an ordinary function.
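To make the first-class point concrete: below, sse is a hypothetical stand-in for a fitting routine that only ever calls the function it receives as f(x, *params), and make_model (also hypothetical) applies the foo/bar pattern to the question's func by capturing A and B. The consumer cannot tell the closure apart from an ordinary top-level function.

```python
import numpy as np


def func(x, *p):
    # Sum of exponentials: p = [a1, k1, a2, k2, ...]
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


def make_model(A, B):
    # Same pattern as foo/bar: A and B are captured by the inner function,
    # while p1 and p4 remain free, named parameters.
    def model(x, p1, p4):
        return func(x, p1, A, B, p4)
    return model


def sse(f, x, y, params):
    # Hypothetical consumer: it only calls f(x, *params).
    return float(np.sum((f(x, *params) - y) ** 2))


x = np.arange(10)
y = func(x, 1.0, 0.5, 0.1, 1.2)
model = make_model(0.5, 0.1)
print(sse(model, x, y, [1.0, 1.2]))  # 0.0
print(sse(func, x, y, [1.0, 0.5, 0.1, 1.2]))  # 0.0
```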
Answer 2:
I think I have something workable, though there may be a way to improve on it.
Here is my code (without all the exception handling):
import numpy as np

def func(x, *p):
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j]*np.exp(-p[j+1]*x)
    return fval
def fix_params(f, fix_pars):
    # fix_pars = ((1, A), (2, B))
    def new_func(x, *pars):
        new_pars = [None]*(len(pars) + len(fix_pars))
        for j, fp in fix_pars:
            new_pars[j] = fp
        for par in pars:
            for j, npar in enumerate(new_pars):
                if npar is None:
                    new_pars[j] = par
                    break
        return f(x, *new_pars)
    return new_func
p1 = [1, 0.5, 0.1, 1.2]
pfix = ((1, 0.5), (2, 0.1))
p2 = [1, 1.2]
new_func = fix_params(func, pfix)
x = np.arange(10)
dat1 = func(x, *p1)
dat2 = new_func(x, *p2)
if (dat1 == dat2).all():
    print("ALL GOOD")
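One possible refinement, sketched here as an assumption rather than the answerer's code: replace the None-sentinel search (quadratic in the number of parameters, and confused if a free value is itself None) with a dict lookup plus an iterator over the free values, and reject out-of-range fixed indices up front.

```python
import inspect

import numpy as np


def func(x, *p):
    fval = 0
    for j in range(0, len(p), 2):
        fval += p[j] * np.exp(-p[j + 1] * x)
    return fval


def fix_params(f, fix_pars):
    # fix_pars = ((index, value), ...) with 0-based indices into *p
    fixed = dict(fix_pars)

    def new_func(x, *pars):
        total = len(pars) + len(fixed)
        if any(not 0 <= i < total for i in fixed):
            raise ValueError("fixed index out of range for the given free parameters")
        free = iter(pars)
        # Take the fixed value where one exists, otherwise the next free value.
        full = [fixed[i] if i in fixed else next(free) for i in range(total)]
        return f(x, *full)

    return new_func


new_func = fix_params(func, ((1, 0.5), (2, 0.1)))
x = np.arange(10)
print(np.array_equal(new_func(x, 1, 1.2), func(x, 1, 0.5, 0.1, 1.2)))  # True
print(inspect.signature(new_func))  # (x, *pars)
```

Note the (x, *pars) signature. As far as I know, when p0 is omitted curve_fit counts the named parameters after the first to decide how many to fit, which it cannot do here, so pass p0 explicitly, e.g. curve_fit(new_func, xdata, ydata, p0=[1, 1.2]).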
Source: https://stackoverflow.com/questions/54382612/how-to-write-a-wrapper-to-fix-arbitrary-parameters-in-a-function