Question
I'm trying to follow and reuse a piece of code (with my own data) suggested by a user named @ThePredator (I couldn't comment on that thread because I don't yet have the required 50 reputation). The full code is as follows:
import numpy as np # This is the Numpy module
from scipy.optimize import curve_fit # The module that contains the curve_fit routine
import matplotlib.pyplot as plt # This is the matplotlib module which we use for plotting the result
""" Below is the function that returns the final y according to the conditions """
def fitfunc(x,a1,a2):
    y1 = (x**(a1) )[x<xc]
    y2 = (x**(a1-a2) )[x>xc]
    y3 = (0)[x==xc]
    y = np.concatenate((y1,y2,y3))
    return y
x = array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
y = array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])
""" In the above code, we have imported 3 modules, namely Numpy, Scipy and matplotlib """
popt,pcov = curve_fit(fitfunc,x,y,p0=(10.0,1.0)) #here we provide random initial parameters a1,a2
a1 = popt[0]
a2 = popt[1]
residuals = y - fitfunc(x,a1,a2)
chi-sq = sum( (residuals**2)/fitfunc(x,a1,a2) ) # This is the chi-square for your fitted curve
""" Now if you need to plot, perform the code below """
curvey = fitfunc(x,a1,a2) # This is your y axis fit-line
plt.plot(x, curvey, 'red', label='The best-fit line')
plt.scatter(x,y, c='b',label='The data points')
plt.legend(loc='best')
plt.show()
I'm having some problems running this code; the errors I get are:
y3 = (0)[x==xc]
TypeError: 'int' object has no attribute '__getitem__'
and also:
xc is undefined
I don't see anything missing in the code (does xc really have to be defined?).
Could the author (@ThePredator), or anyone else familiar with this, please help me identify what I've missed?
New version of the code:
import numpy as np # This is the Numpy module
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
def fitfunc(x, a1, a2, xc):
    if x.all() < xc:
        y = x**a1
    elif x.all() > xc:
        y = x**(a1 - a2) * x**a2
    else:
        y = 0
    return y
xc = 2
x = np.array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
y = np.array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])
popt,pcov = curve_fit(fitfunc,x,y,p0=(1.0,1.0))
a1 = popt[0]
a2 = popt[1]
residuals = y - fitfunc(x, a1, a2, xc)
chisq = sum((residuals**2)/fitfunc(x, a1, a2, xc))
curvey = [fitfunc(val, a1, a2, xc) for val in x] # y-axis fit-line
plt.plot(x, curvey, 'red', label='The best-fit line')
plt.scatter(x,y, c='b',label='The data points')
plt.legend(loc='best')
plt.show()
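The new version still has two problems that are easy to check directly. x.all() returns a single boolean (True when every element of x is non-zero), so x.all() < xc compares True or False with xc instead of testing each point; with xc = 2 both True (1) and False (0) are smaller than xc, so the first branch is taken for every element. In addition, fitfunc now takes four arguments while p0 has only two entries, and curve_fit calls the model as f(xdata, *p0), so xc never receives a value. A small sketch using a few of the question's x values:

```python
import numpy as np

def fitfunc(x, a1, a2, xc):
    # Conditions copied from the new version of the question.
    if x.all() < xc:
        y = x**a1
    elif x.all() > xc:
        y = x**(a1 - a2) * x**a2
    else:
        y = 0
    return y

x = np.array([0.001, 0.524, 2.180, 35460])  # a few of the question's x values
xc = 2

print(x.all())       # one boolean for the whole array
print(x.all() < xc)  # compares that boolean (True == 1) with xc, not each x
print(x < xc)        # the element-wise comparison that was presumably intended

# curve_fit calls f(xdata, *p0); with p0=(1.0, 1.0) this becomes
# fitfunc(x, 1.0, 1.0), which raises TypeError because xc is missing.
try:
    fitfunc(x, *(1.0, 1.0))
except TypeError as err:
    print("TypeError:", err)
```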
Answer 1:
There are multiple errors/typos in your code.
1) You cannot use - in variable names in Python (for example, chi-sq should be chi_sq).
2) You should either add from numpy import array or replace array with np.array. Currently the name array is not defined.
3) xc is not defined; you should set it before calling fitfunc().
4) y3 = (0)[x==xc] is not valid; it should probably be y3 = np.zeros(len(x))[x==xc] or y3 = np.zeros(np.sum(x==xc)).
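The TypeError in the question comes from point 4: (0) is just the integer 0, and an integer cannot be indexed with a boolean mask. Both suggested replacements give one zero for each element where x == xc. A quick check, using a hypothetical x that contains xc exactly once:

```python
import numpy as np

x = np.array([0.5, 1.0, 2.0, 3.0])  # hypothetical sample containing xc once
xc = 2.0

# Indexing a plain int with a mask fails, exactly like (0)[x == xc].
zero = 0
try:
    zero[x == xc]
except TypeError as err:
    print("TypeError:", err)

y3_a = np.zeros(len(x))[x == xc]  # mask a full-length array of zeros
y3_b = np.zeros(np.sum(x == xc))  # build the zeros directly from the count
print(y3_a, y3_b)                 # both contain a single 0.0
```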
Your fitfunc() is also wrong, because np.concatenate groups the results by branch, which changes the order of the function values (the images of x) relative to x. What you want is:
def fit_function(x, a1, a2, xc):
    if x < xc:
        y = x**a1
    elif x > xc:
        y = x**(a1 - a2) * x**a2
    else:
        y = 0
    return y
xc = 2 #or any value you want
curvey = [fit_function(val, a1, a2, xc) for val in x]
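The remark about order can be made concrete: np.concatenate stacks the x < xc values first, then the x > xc values, then the zeros, so the result only lines up with x when x is sorted and does not contain xc. The list comprehension above avoids that by evaluating one point at a time. A vectorised alternative, sketched here with np.piecewise (an alternative not taken from the answer) and the question's branch formulas, keeps every value in place and accepts the whole array at once:

```python
import numpy as np

def fitfunc_concat(x, a1, a2, xc):
    # The question's approach: pieces are stacked branch by branch,
    # so the output no longer follows the order of x.
    y1 = x[x < xc] ** a1
    y2 = x[x > xc] ** (a1 - a2)
    y3 = np.zeros(np.sum(x == xc))
    return np.concatenate((y1, y2, y3))

def fitfunc_piecewise(x, a1, a2, xc):
    # np.piecewise evaluates each branch only where its condition holds and
    # writes the result back in place; the trailing 0.0 is the default,
    # used where neither condition is true (x == xc).
    x = np.asarray(x, dtype=float)
    return np.piecewise(x, [x < xc, x > xc],
                        [lambda t: t**a1, lambda t: t**(a1 - a2), 0.0])

x = np.array([3.0, 0.5, 2.0, 1.0])  # hypothetical, deliberately unsorted
a1, a2, xc = 2.0, 1.0, 2.0
print(fitfunc_concat(x, a1, a2, xc))     # grouped by branch, not in x's order
print(fitfunc_piecewise(x, a1, a2, xc))  # one value per x, same order as x
```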
Answer 2:
Hi, define your function as follows and that should solve it. x is an array (or list), so the function should loop over it and return y as a list (or array) of the same length. You can then use it in curve_fit.
def fit_function(x, a1, a2, xc):
    y = []
    for xx in x:
        # use xx (the current element), not the whole array x
        if xx<xc:
            y.append(xx**a1)
        elif xx>xc:
            y.append(xx**(a1 - a2) * xx**a2)
        else:
            y.append(0.0)
    return y
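Putting the pieces together, a minimal end-to-end sketch: Answer 2's loop with the current element used inside each branch, xc held fixed at a hypothetical 2.0 by wrapping the function in a lambda so that curve_fit only varies a1 and a2, and the question's chi-square under a valid name. Two assumptions to flag: the above-break branch uses the question's original x**(a1 - a2), because xx**(a1 - a2) * xx**a2 simplifies to xx**a1 and would leave a2 undetermined; and the data are synthetic and noise-free, generated from the model itself, rather than the question's measurements.

```python
import numpy as np
from scipy.optimize import curve_fit

def fit_function(x, a1, a2, xc):
    # Answer 2's loop; the x > xc branch uses the question's x**(a1 - a2) (assumption).
    y = []
    for xx in x:
        if xx < xc:
            y.append(xx**a1)
        elif xx > xc:
            y.append(xx**(a1 - a2))
        else:
            y.append(0.0)
    return np.asarray(y)  # array output keeps residual arithmetic element-wise

xc = 2.0  # hypothetical break point, held fixed during the fit

# Synthetic, noise-free data generated from the model (illustration only).
x = np.linspace(0.5, 4.0, 30)
y = fit_function(x, 1.5, 0.5, xc)

# Bind xc so the model has exactly two free parameters, matching p0.
popt, pcov = curve_fit(lambda x, a1, a2: fit_function(x, a1, a2, xc),
                       x, y, p0=(1.0, 1.0))
a1, a2 = popt

residuals = y - fit_function(x, a1, a2, xc)
chi_sq = np.sum(residuals**2 / fit_function(x, a1, a2, xc))  # question's formula
print(popt, chi_sq)  # compare popt with the generating values 1.5 and 0.5
```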
Source: https://stackoverflow.com/questions/32271090/curve-fitting-with-broken-power-law-in-python