Question
I have some discrete points, and I interpolated them with the scipy.interpolate.splprep() function (B-spline interpolation) to get a satisfactory smooth curve. Here is the code (adapted from an answer to another question) and the result I got.
import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt
# x and y are points sampled randomly
x = sampledx
y = sampledy
# append the starting x,y coordinates
x = np.r_[x, x[0]]
y = np.r_[y, y[0]]
# fit splines to x=f(u) and y=g(u), treating both as periodic. also note that s=0
# is needed in order to force the spline fit to pass through all the input points.
tck, u = interpolate.splprep([x, y], s=0, per=True)
# evaluate the spline fits for 1000 evenly spaced distance values
xi, yi = interpolate.splev(np.linspace(0, 1, 1000), tck)
# plot the result
fig, ax = plt.subplots(figsize=(12, 12))
ax.plot(x, y, 'or')
ax.plot(xi, yi, '-b')
plt.show()
[Figure: obtained curve]
As far as I know, the model obtained by cubic spline interpolation is a series of polynomials. I want to extract this function model, so I printed the contents of tck:
[array([-0.30733587, -0.28200105, -0.22446703, 0. , 0.03802363,
0.07911629, 0.09557235, 0.15790186, 0.20199024, 0.24140097,
0.26977782, 0.31416052, 0.35118666, 0.42856196, 0.45166591,
0.49503978, 0.51375395, 0.56799754, 0.59262884, 0.61845984,
0.65603571, 0.69266413, 0.71799895, 0.77553297, 1. ,
1.03802363, 1.07911629, 1.09557235]),
[array([229.12471144, -98.86968613, 50.15238681, 83.22909902,
88.9466649 , 103.43169139, 158.24339347, 200.28605252,
245.21725764, 291.11861604, 356.23057282, 404.75955996,
429.18100345, 435.79417275, 430.58694659, 402.28422935,
381.19094487, 360.28746542, 316.79933633, 271.50003508,
242.72352701, 229.12471144, -98.86968613, 50.15238681]),
array([-77.44508113, 184.01906954, 197.43235399, 226.25242057,
275.95919475, 329.12264277, 360.20146464, 378.28519513,
391.18454729, 390.47825093, 380.06668473, 339.92688063,
285.65908782, 250.27639394, 201.82803336, 168.81117187,
133.96870427, 94.65595445, 126.9811583 , 121.02433492,
78.83626675, -77.44508113, 184.01906954, 197.43235399])],
3]
After consulting the relevant documentation, I learned that the first array is the list of knots, the second and third arrays are the lists of coefficients, and the last single number is the degree. If I understand correctly, the function model will be composed of 7 polynomials whose highest power is 3. How can I extract the function model (the polynomials) from these parameters? Thanks a lot.
Answer 1:
The tck returned by interpolate.splprep consists of 3 parts:
tck[0]: the knots of the B-splines (these are values of the parameter u)
tck[1]: the x and y coordinates of the calculated control points (the B-spline coefficients)
tck[2]: the degree of the B-splines (3 for these cubic B-splines)
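To check this structure on real data, here is a minimal sketch (assuming the 4-point quadrangle from the example further down) that unpacks tck and counts coefficients and polynomial pieces. For degree k and n knots there are n - k - 1 coefficients per coordinate, and, assuming no repeated interior knots, the polynomial pieces on the evaluation range [t[k], t[n-k-1]] (here [0, 1]) are the knot intervals in between.

```python
import numpy as np
from scipy import interpolate

# the closed quadrangle from the example below (start point appended at the end)
x = np.r_[[0, 1, 40, 45], 0]
y = np.r_[[0, 22, 35, 7], 0]

tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)
t, (cx, cy), k = tck  # knots, per-coordinate coefficients, degree

# one coefficient (control-point coordinate) per basis function
n_coef = len(t) - k - 1

# the curve is evaluated for u in [t[k], t[n-k-1]]; every knot interval
# inside that range carries one polynomial of degree k
n_pieces = (len(t) - k - 1) - k

print("knots:", len(t), "degree:", k, "coefficients:", len(cx), "pieces on [0, 1]:", n_pieces)
```

For the knot vector printed in the question (28 knots, k = 3) the same count gives 24 - 3 = 21 polynomial pieces per coordinate on [0, 1].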
interpolate.splprep also returns an array of u values (u_ticks in the code below). These are the values of u at which the B-spline passes through each of the points to be interpolated. In the plot produced by the code below, they are marked with black lines on the colorbar.
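One way to see what these u values mean is to evaluate the spline at them. This sketch (again assuming the 4-point quadrangle) checks that splev at u_ticks returns the input points, and that, when no u is passed in, splprep's default parameter values are the cumulative chord lengths between consecutive points normalized to [0, 1], as described in the scipy documentation.

```python
import numpy as np
from scipy import interpolate

# closed quadrangle, start point appended at the end
x = np.r_[[0, 1, 40, 45], 0]
y = np.r_[[0, 22, 35, 7], 0]
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)

# the spline passes through the input points at the returned u values
xs_at_ticks, ys_at_ticks = interpolate.splev(u_ticks, tck)
print(np.allclose(xs_at_ticks, x, atol=1e-6), np.allclose(ys_at_ticks, y, atol=1e-6))

# default u: cumulative chord length, normalized to run from 0 to 1
chord = np.hypot(np.diff(x), np.diff(y))
u_chord = np.r_[0, np.cumsum(chord)] / chord.sum()
print(np.allclose(u_ticks, u_chord))
print(np.round(u_ticks, 4))
```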
From the knots, a set of B-spline basis functions can be calculated, one basis function for each control point (24 in your example).
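The basis functions can also be built in scipy without sympy. In this sketch (assuming the 4-point quadrangle and scipy's BSpline class), basis function i is taken to be the spline on the same knots whose coefficient vector is the i-th unit vector. On the evaluation range [0, 1] these functions are non-negative and sum to 1, the partition-of-unity property of B-splines.

```python
import numpy as np
from scipy import interpolate
from scipy.interpolate import BSpline

x = np.r_[[0, 1, 40, 45], 0]
y = np.r_[[0, 22, 35, 7], 0]
tck, _ = interpolate.splprep([x, y], s=0, per=True)
t, c, k = tck
n_basis = len(t) - k - 1  # one basis function per control point

us = np.linspace(0, 1, 201)
# basis function i = spline whose coefficient vector is the i-th unit vector
basis = np.column_stack([BSpline(t, np.eye(n_basis)[i], k)(us) for i in range(n_basis)])

print(basis.shape)                        # (201, 7)
print(np.allclose(basis.sum(axis=1), 1))  # True: the basis functions sum to 1 on [0, 1]
```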
To draw the curve, u varies between 0 and 1; this is the np.linspace(0, 1, 1000) in your example code. For each u value, every basis function is evaluated at u and multiplied by the x coordinate of its control point, and the sum over all these products gives x(u). The same happens for y.
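In formula form this is x(u) = Σ_i cx_i · B_i(u) and y(u) = Σ_i cy_i · B_i(u), where B_i are the basis functions and (cx_i, cy_i) are the control points from tck[1]. A minimal numeric check of this, assuming scipy's BSpline class to evaluate the basis functions, builds the matrix of basis values and multiplies it by the control-point coordinates, then compares the result with splev:

```python
import numpy as np
from scipy import interpolate
from scipy.interpolate import BSpline

x = np.r_[[0, 1, 40, 45], 0]
y = np.r_[[0, 22, 35, 7], 0]
tck, _ = interpolate.splprep([x, y], s=0, per=True)
t, (cx, cy), k = tck
n_basis = len(t) - k - 1
# keep exactly one coefficient per basis function
cx = np.asarray(cx)[:n_basis]
cy = np.asarray(cy)[:n_basis]

us = np.linspace(0, 1, 1000)
# matrix of basis-function values: row = u value, column = basis function
B = np.column_stack([BSpline(t, np.eye(n_basis)[i], k)(us) for i in range(n_basis)])

# x(u) = sum_i cx[i] * B_i(u), and likewise for y
xs = B @ cx
ys = B @ cy

xi, yi = interpolate.splev(us, tck)
print(np.allclose(xs, xi), np.allclose(ys, yi))  # True True
```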
Sympy's bspline_basis_set can be used to show what these functions look like.
Here is an example with just 4 points; as you'll notice, the functions quickly become quite complex.
import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt
# x and y for a simple quadrangle
x = [0, 1, 40, 45]
y = [0, 22, 35, 7]
# append the starting x,y coordinates
x = np.r_[x, x[0]]
y = np.r_[y, y[0]]
# fit splines to x=f(u) and y=g(u), treating both as periodic. also note that s=0
# is needed in order to force the spline fit to pass through all the input points.
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)
# evaluate the spline fits for 1000 evenly spaced distance values
xi, yi = interpolate.splev(np.linspace(0, 1, 1000), tck)
# plot the result
fig, ax = plt.subplots(figsize=(12, 12))
ax.plot(x, y, 'Pk', ms=10, label='Points to interpolate')
ax.plot(xi, yi, '-b', lw=1, label='Interpolating spline (splev)', zorder=0)
ax.plot(tck[1][0], tck[1][1], 'om', ls=':', label='Calculated control points')
from sympy import lambdify, bspline_basis_set
from sympy.abc import u
basis = bspline_basis_set(tck[2], tck[0], u)
for i, b in enumerate(basis):
print(f"Basis {i} :", b)
# convert the basis functions to numpy so they can be evaluated quicker
np_basis = [lambdify(u, b, modules=['numpy']) for b in basis]
tck_x = tck[1][0]
tck_y = tck[1][1]
us = np.linspace(0, 1, 100)
# x(u) = sum over control points of (control-point coordinate * basis value at u)
xs = [sum(cx * bi(u_val) for cx, bi in zip(tck_x, np_basis)) for u_val in us]
ys = [sum(cy * bi(u_val) for cy, bi in zip(tck_y, np_basis)) for u_val in us]
plt.scatter(xs, ys, c=us, s=40, marker='o', cmap='tab10')
plt.legend()
cbar = plt.colorbar(label='u values')
for t in u_ticks:
# mark the position of the u_ticks at the color bar
    cbar.ax.axhline(t, lw=3, color='black', clip_on=False)
plt.show()
Output:
Basis 0 : Piecewise((7.83358627878421*u**3 + 19.7262258572059*u**2 + 16.5579328428993*u + 4.63283654316489, (u >= -0.83938676170286) & (u <= -0.539571441177499)), (-34.7262442279844*u**3 - 49.1659813912158*u**2 - 20.6143347080305*u - 2.05286144826537, (u >= -0.539571441177499) & (u <= -0.332135154281002)), (23.3437491730212*u**3 + 8.69527726080352*u**2 - 1.39657663874914*u + 0.0747695654932114, (u >= -0.332135154281002) & (u <= 0)), (-18.0459953633398*u**3 + 8.69527726080352*u**2 - 1.39657663874914*u + 0.0747695654932114, (u >= 0) & (u <= 0.16061323829714)), (0, True))
Basis 1 : Piecewise((12.7600892248919*u**3 + 20.6549391978852*u**2 + 11.1448153104365*u + 2.00447468623643, (u >= -0.539571441177499) & (u <= -0.332135154281002)), (-24.4055001260175*u**3 - 16.3770570611408*u**2 - 1.15481248038858*u + 0.642761761601563, (u >= -0.332135154281002) & (u <= 0)), (51.0502963670014*u**3 - 16.3770570611408*u**2 - 1.15481248038858*u + 0.642761761601563, (u >= 0) & (u <= 0.16061323829714)), (-9.14007459775806*u**3 + 12.6250541237277*u**2 - 5.81293547524402*u + 0.892147167798265, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (0, True))
Basis 2 : Piecewise((7.70949185527263*u**3 + 7.68177980033731*u**2 + 2.55138911913772*u + 0.282468672905225, (u >= -0.332135154281002) & (u <= 0)), (-53.251633917268*u**3 + 7.68177980033731*u**2 + 2.55138911913772*u + 0.282468672905225, (u >= 0) & (u <= 0.16061323829714)), (29.8321355272912*u**3 - 32.3512799809336*u**2 + 8.98122848955063*u - 0.0617704347655956, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (-14.2299460617349*u**3 + 28.5110421933306*u**2 - 19.0415227957366*u + 4.2390545614098, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (0, True))
Basis 3 : Piecewise((20.2473329136064*u**3, (u >= 0) & (u <= 0.16061323829714)), (-28.5256472083174*u**3 + 23.5007588363526*u**2 - 3.77453297914672*u + 0.202079988280036, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (36.1961010648274*u**3 - 65.8984650092776*u**2 + 37.387422815947*u - 6.1153000067368, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (-6.64774090227629*u**3 + 19.9432227068289*u**2 - 19.9432227068289*u + 6.64774090227629, (u >= 0.667864845718998) & (u <= 1.0)), (0, True))
Basis 4 : Piecewise((7.83358627878421*u**3 - 3.77453297914672*u**2 + 0.606239964840107*u - 0.0324567213127046, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (-34.7262442279844*u**3 + 55.0127512927375*u**2 - 26.4611046095522*u + 4.1217360965338, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (23.3437491730212*u**3 - 61.3359702582601*u**2 + 51.2441163587074*u - 13.1771257079753, (u >= 0.667864845718998) & (u <= 1.0)), (-18.0459953633398*u**3 + 62.8332633508229*u**2 - 72.9251172503755*u + 28.2126188283857, (u >= 1.0) & (u <= 1.16061323829714)), (0, True))
Basis 5 : Piecewise((12.7600892248919*u**3 - 17.6253284767905*u**2 + 8.11520458934184*u - 1.2454906512068, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (-24.4055001260175*u**3 + 56.8394433169118*u**2 - 41.6171987361595*u + 9.82601730686685, (u >= 0.667864845718998) & (u <= 1.0)), (51.0502963670015*u**3 - 169.527946162145*u**2 + 184.750190742898*u - 65.6297791861522, (u >= 1.0) & (u <= 1.16061323829714)), (-9.14007459775806*u**3 + 40.0452779170019*u**2 - 58.4832675159736*u + 28.470211364528, (u >= 1.16061323829714) & (u <= 1.4604285588225)), (0, True))
Basis 6 : Piecewise((7.70949185527263*u**3 - 15.4466957654806*u**2 + 10.316305084281*u - 2.29663250116781, (u >= 0.667864845718998) & (u <= 1.0)), (-53.2516339172681*u**3 + 167.436681552142*u**2 - 172.567072233341*u + 58.6644932713729, (u >= 1.0) & (u <= 1.16061323829714)), (29.8321355272912*u**3 - 121.847686562807*u**2 + 163.180195033291*u - 71.226414432541, (u >= 1.16061323829714) & (u <= 1.4604285588225)), (-14.2299460617349*u**3 + 71.2008803785352*u**2 - 118.753445367602*u + 66.0215656122119, (u >= 1.4604285588225) & (u <= 1.667864845719)), (0, True))
Alternatively, as mentioned in this post, sympy has a not-yet-documented function, interpolating_spline, that calculates the piecewise functions already combined with the x values. (Note that it uses 'x' where we use 'u', and 'y' where we use 'x'. This can be confusing...)
To get this to work with a closed (circular) list, two extra nodes need to be added at the front and two at the end. So, together with the repeated node added earlier, there are now 9 nodes to represent the 4 original points.
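The padding in the code below follows a fixed pattern, so it can be wrapped in a small helper. The name pad_periodic and its m parameter are hypothetical, not part of sympy or scipy; the sketch assumes, as in this answer, that u_ticks runs from 0 to 1 and that the start point has been appended to the end of the coordinate list.

```python
import numpy as np
from scipy import interpolate

def pad_periodic(u_ticks, vals, m=2):
    """Hypothetical helper: extend a closed curve's nodes by m on each side.

    Assumes u_ticks goes from 0 to 1 (period 1) and vals[-1] == vals[0].
    """
    u_ticks = np.asarray(u_ticks, dtype=float)
    vals = np.asarray(vals, dtype=float)
    # the m nodes before the closing point, shifted one period to the left
    u_front, v_front = u_ticks[-m - 1:-1] - 1, vals[-m - 1:-1]
    # the m nodes after the starting point, shifted one period to the right
    u_back, v_back = u_ticks[1:m + 1] + 1, vals[1:m + 1]
    return np.r_[u_front, u_ticks, u_back], np.r_[v_front, vals, v_back]

x = np.r_[[0, 1, 40, 45], 0]
y = np.r_[[0, 22, 35, 7], 0]
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)

us, xs = pad_periodic(u_ticks, x)
_, ys = pad_periodic(u_ticks, y)
print(len(us))  # 9 nodes for the 4 original points
print(xs)       # x values 40, 45, 0, 1, 40, 45, 0, 1, 40
```

With the default m=2 this yields the same us, xs and ys lists as the hand-written version, so they can be passed to interpolating_spline in the same way.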
from sympy import interpolating_spline, lambdify
from sympy.abc import u
# ... the same code as above, but replacing the complete sympy part
# pad the u_ticks from splprep periodically: two nodes before 0 and two after 1
us = [u_ticks[-3] - 1, u_ticks[-2] - 1, *u_ticks, u_ticks[1] + 1, u_ticks[2] + 1]
xs = [*x[-3:-1], *x, *x[1:3]]
ys = [*y[-3:-1], *y, *y[1:3]]
interpx = interpolating_spline(tck[2], u, us, xs)
interpy = interpolating_spline(tck[2], u, us, ys)
print(interpx)
print(interpy)
fx = lambdify(u, interpx, modules=['numpy'])
fy = lambdify(u, interpy, modules=['numpy'])
us = np.linspace(0, 1, 100)
plt.scatter(fx(us), fy(us), c=us, s=40, marker='o', cmap='tab10')  # label="sympy's interpolating_spline"
Since the x values are now already summed in, there is just one piecewise formula for x and one for y:
# for x:
Piecewise((259.449085976667*u**3 + 332.098590899285*u**2 - 53.8062007647187*u - 8.88178419700125e-16, (u >= -0.332135154281002) & (u <= 0.16061323829714)), (-889.09792969929*u**3 + 885.514157471979*u**2 - 142.692067036006*u + 4.75874894022597, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (-281.671950803575*u**3 + 46.4853533090758*u**2 + 243.620756075287*u - 54.5310698597021, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (976.463184688985*u**3 - 2474.30733116909*u**2 + 1927.16957338388*u - 429.32542690377, (u >= 0.667864845718998) & (u <= 1.16061323829714)))
# for y:
Piecewise((-737.592577045201*u**3 + 194.240200950605*u**2 + 124.804852561614*u + 3.5527136788005e-15, (u >= -0.332135154281002) & (u <= 0.16061323829714)), (-427.62807998269*u**3 + 44.8869960595423*u**2 + 148.792954449223*u - 1.28426890825692, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (1396.06082019756*u**3 - 2474.14836009222*u**2 + 1308.6287731051*u - 179.291447059738, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (-2.71308577093816*u**3 + 328.427396624023*u**2 - 563.113052269992*u + 237.398741416907, (u >= 0.667864845718998) & (u <= 1.16061323829714)))
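If the goal is just to get the polynomial pieces out of tck without sympy, scipy can convert a B-spline to piecewise-polynomial form with PPoly.from_spline. The sketch below (assuming the 4-point quadrangle and a reasonably recent scipy) converts each coordinate's (t, c, k) tuple separately and prints the cubic on each knot interval inside [0, 1]: first PPoly's local coefficients in powers of (u - a), then the same cubic expanded into ordinary powers of u with numpy.poly1d.

```python
import numpy as np
from scipy import interpolate
from scipy.interpolate import PPoly

x = np.r_[[0, 1, 40, 45], 0]
y = np.r_[[0, 22, 35, 7], 0]
tck, _ = interpolate.splprep([x, y], s=0, per=True)
t, (cx, cy), k = tck
n_coef = len(t) - k - 1

# convert each coordinate's (t, c, k) B-spline into piecewise-polynomial form
ppx = PPoly.from_spline((t, np.asarray(cx)[:n_coef], k))
ppy = PPoly.from_spline((t, np.asarray(cy)[:n_coef], k))

# on [ppx.x[i], ppx.x[i+1]]: x(u) = sum_m ppx.c[m, i] * (u - ppx.x[i])**(k - m)
for i in range(k, len(t) - k - 1):  # only the knot intervals inside [0, 1]
    a, b = ppx.x[i], ppx.x[i + 1]
    # substitute (u - a) to expand the local cubic into ordinary powers of u
    px = np.poly1d(ppx.c[:, i])(np.poly1d([1.0, -a]))
    py = np.poly1d(ppy.c[:, i])(np.poly1d([1.0, -a]))
    print(f"u in [{a:.4f}, {b:.4f}]")
    print("  local x coefficients:", np.round(ppx.c[:, i], 4))
    # coefficients listed highest power first
    print("  x(u) =", np.round(px.coeffs, 4), " y(u) =", np.round(py.coeffs, 4))
```

The local form in powers of (u - a) is generally better conditioned numerically than the expanded powers of u, so it is the one to prefer if the pieces are evaluated later; ppx(u) and ppy(u) also evaluate the whole curve directly.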
Source: https://stackoverflow.com/questions/60105444/how-to-extract-the-function-model-polynomials-from-scipy-interpolate-splprep