scipy.interpolate.UnivariateSpline not smoothing regardless of parameters

死守一世寂寞 2021-01-01 04:31

I'm having trouble getting scipy.interpolate.UnivariateSpline to use any smoothing when interpolating. Based on the function's page as well as some previous posts, I beli

4 Answers
  •  迷失自我
    2021-01-01 05:06

    I had trouble getting BigChef's answer to run, so here is a variation that works on Python 3.6:

    # Imports
    import matplotlib.pyplot as plt  # pyplot replaces the discouraged pylab interface
    import numpy
    import scipy.interpolate  # import the submodule explicitly; plain "import scipy" may not load it
    import sklearn.cluster
    
    # Set up original data - note that it's monotonically increasing by X value!
    data = {}
    data['original'] = {}
    data['original']['x'] = [0, 5024.2059124920379, 7933.1645067836089, 7990.4664106277542, 9879.9717114947653, 13738.60563208926, 15113.277958924193]
    data['original']['y'] = [0.0, 3072.5653360000988, 5477.2689107965398, 5851.6866463790966, 6056.3852496014106, 7895.2332350173638, 9154.2956175610598]
    
    # Cluster the (x, y) pairs with mean shift, then sort the cluster centres by x and save them
    inputNumpy = numpy.column_stack((data['original']['x'], data['original']['y']))
    meanShift = sklearn.cluster.MeanShift()
    meanShift.fit(inputNumpy)
    clusteredData = [[pair[0], pair[1]] for pair in meanShift.cluster_centers_]
    
    clusteredData.sort(key=lambda li: li[0])
    data['clustered'] = {}
    data['clustered']['x'] = [pair[0] for pair in clusteredData]
    data['clustered']['y'] = [pair[1] for pair in clusteredData]
    
    # Build a linear (k=1) spline through the cluster centres and evaluate it on a 20-unit grid.
    # UnivariateSpline needs more than k points, i.e. at least 2 clusters here.
    mySpline = scipy.interpolate.UnivariateSpline(x=data['clustered']['x'], y=data['clustered']['y'], k=1)
    xi = numpy.arange(0, int(round(max(data['original']['x']), -3)) + 3000, 20)
    yi = mySpline(xi)
    
    # Plot the clustered points, the spline prediction and the original points
    plt.plot(data['clustered']['x'], data['clustered']['y'], "D", label="Datapoints (%s)" % 'clustered')
    plt.plot(xi, yi, label="Predicted (%s)" % 'clustered')
    plt.plot(data['original']['x'], data['original']['y'], "o", label="Datapoints (%s)" % 'original')
    
    # Show the plot
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.legend()  # without this the label= arguments are never displayed
    plt.show()
    
