How to overplot a line on a scatter plot in python?

轻奢々 2020-11-28 02:18

I have two vectors of data and I've put them into matplotlib.scatter(). Now I'd like to overplot a linear fit to these data. How would I do this? I've tried

7 answers
  • 2020-11-28 02:48
    import numpy as np
    from numpy.polynomial.polynomial import polyfit
    import matplotlib.pyplot as plt
    
    # Sample data
    x = np.arange(10)
    y = 5 * x + 10
    
    # Fit with polyfit
    b, m = polyfit(x, y, 1)
    
    plt.plot(x, y, '.')
    plt.plot(x, b + m * x, '-')
    plt.show()
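
    One detail worth spelling out: numpy.polynomial.polynomial.polyfit returns coefficients from lowest degree to highest, which is why the result above unpacks as b, m. The older np.polyfit returns them in the opposite order. A minimal sketch comparing the two on the same sample data (the alias polyfit_ascending and the *_legacy names are only illustrative):

```python
import numpy as np
from numpy.polynomial.polynomial import polyfit as polyfit_ascending

# Same sample data as the answer above
x = np.arange(10)
y = 5 * x + 10

# numpy.polynomial.polynomial.polyfit: lowest degree first -> (intercept, slope)
b, m = polyfit_ascending(x, y, 1)

# np.polyfit: highest degree first -> (slope, intercept)
m_legacy, b_legacy = np.polyfit(x, y, 1)

print(f"intercept={b:.3f}, slope={m:.3f}")
```

    Mixing up the two orderings silently swaps slope and intercept, so it helps to name the unpacked values explicitly.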
    

  • 2020-11-28 02:52

    I'm partial to statsmodels (formerly scikits.statsmodels). Here's an example:

    import statsmodels.api as sm
    import numpy as np
    import matplotlib.pyplot as plt
    
    X = np.random.rand(100)
    Y = X + np.random.rand(100)*0.1
    
    results = sm.OLS(Y,sm.add_constant(X)).fit()
    
    print(results.summary())
    
    plt.scatter(X,Y)
    
    X_plot = np.linspace(0,1,100)
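    # In current statsmodels, add_constant puts the constant column first,
    # so params[0] is the intercept and params[1] is the slope
    plt.plot(X_plot, X_plot*results.params[1] + results.params[0])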
    
    plt.show()
    

    The only tricky part is sm.add_constant(X), which adds a column of ones to X in order to get an intercept term.

         Summary of Regression Results
    =======================================
    | Dependent Variable:            ['y']|
    | Model:                           OLS|
    | Method:                Least Squares|
    | Date:               Sat, 28 Sep 2013|
    | Time:                       09:22:59|
    | # obs:                         100.0|
    | Df residuals:                   98.0|
    | Df model:                        1.0|
    ==============================================================================
    |                   coefficient     std. error    t-statistic          prob. |
    ------------------------------------------------------------------------------
    | x1                      1.007       0.008466       118.9032         0.0000 |
    | const                 0.05165       0.005138        10.0515         0.0000 |
    ==============================================================================
    |                          Models stats                      Residual stats  |
    ------------------------------------------------------------------------------
    | R-squared:                     0.9931   Durbin-Watson:              1.484  |
    | Adjusted R-squared:            0.9930   Omnibus:                    12.16  |
    | F-statistic:                1.414e+04   Prob(Omnibus):           0.002294  |
    | Prob (F-statistic):        9.137e-108   JB:                        0.6818  |
    | Log likelihood:                 223.8   Prob(JB):                  0.7111  |
    | AIC criterion:                 -443.7   Skew:                     -0.2064  |
    | BIC criterion:                 -438.5   Kurtosis:                   2.048  |
    ------------------------------------------------------------------------------
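
    To see what the column of ones contributes, the same fit can be sketched in plain NumPy. This is not what statsmodels runs internally, only an equivalent least-squares solve. It assumes the ones column comes first, which is the default for sm.add_constant in current releases (worth checking against your installed version); the seed and variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility
X = rng.random(100)
Y = X + rng.random(100) * 0.1

# Design matrix: a column of ones (intercept term) followed by X
A = np.column_stack([np.ones_like(X), X])

# Least-squares solution: coef[0] is the intercept, coef[1] is the slope
coef, residuals, rank, singular_values = np.linalg.lstsq(A, Y, rcond=None)
intercept, slope = coef

print(f"intercept={intercept:.4f}, slope={slope:.4f}")
```

    Without the ones column the fit would be forced through the origin, because the model would have no term that is independent of X.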
    

  • 2020-11-28 02:58

    I like Seaborn's regplot or lmplot for this:
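
    A minimal sketch of the regplot route, assuming seaborn is installed. The sample data, seed, and ci=None setting are illustrative choices, not part of the original answer.

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility
x = rng.random(100)
y = x + rng.random(100) * 0.1

fig, ax = plt.subplots()
# regplot draws the scatter points and the fitted regression line together;
# ci=None turns off the shaded confidence band
sns.regplot(x=x, y=y, ci=None, ax=ax)
```

    Call plt.show() or fig.savefig(...) afterwards to display or save the figure. lmplot is the figure-level counterpart and expects the data as a DataFrame with column names.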

  • 2020-11-28 03:03

    A one-line version of this excellent answer to plot the line of best fit is:

    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))
    

    Using np.unique(x) instead of x handles the case where x isn't sorted or has duplicate values.

    The call to poly1d is an alternative to writing out m*x + b like in this other excellent answer.
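
    The one-liner can be unpacked into its steps. A small sketch with made-up data, where x is unsorted and contains a repeated value:

```python
import numpy as np

# Unsorted x with a repeated value (illustrative data)
x = np.array([3.0, 1.0, 2.0, 3.0, 0.0])
y = 2.0 * x + 1.0

# np.polyfit returns (slope, intercept); poly1d wraps them in a callable
fit = np.poly1d(np.polyfit(x, y, 1))

xs = np.unique(x)  # sorted, duplicates removed
ys = fit(xs)       # same as slope * xs + intercept

print(xs, ys)
```

    Passing xs and ys to plt.plot(xs, ys) then draws the line from left to right. For a straight line the ordering is mostly cosmetic, but it matters once the fitted polynomial has a higher degree.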

  • 2020-11-28 03:08

    Another way to do it, using axes.get_xlim():

    import matplotlib.pyplot as plt
    import numpy as np
    
    def scatter_plot_with_correlation_line(x, y, graph_filepath):
        '''
        http://stackoverflow.com/a/34571821/395857
        x does not have to be ordered.
        '''
        # Create scatter plot
        plt.scatter(x, y)
    
        # Add correlation line
        axes = plt.gca()
        m, b = np.polyfit(x, y, 1)
        X_plot = np.linspace(axes.get_xlim()[0],axes.get_xlim()[1],100)
        plt.plot(X_plot, m*X_plot + b, '-')
    
        # Save figure
        plt.savefig(graph_filepath, dpi=300, format='png', bbox_inches='tight')
    
    def main():
        # Data
        x = np.random.rand(100)
        y = x + np.random.rand(100)*0.1
    
        # Plot
        scatter_plot_with_correlation_line(x, y, 'scatter_plot.png')
    
    if __name__ == "__main__":
        main()
        #cProfile.run('main()') # if you want to do some profiling
    

  • 2020-11-28 03:13
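    Check the parameter order in the statsmodels answer. When the constant column comes first (the default for sm.add_constant in current statsmodels releases), results.params[0] is the intercept and results.params[1] is the slope, so this line has the two swapped:

    plt.plot(X_plot, X_plot*results.params[0] + results.params[1])
    

    versus the form that matches that ordering:

    plt.plot(X_plot, X_plot*results.params[1] + results.params[0])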
    