Add trend line to pandas

自闭症患者 2021-01-02 09:54

I have time-series data, as follows:

                  emplvl
date                    
2003-01-01  10955.000000
2003-04-01  11090.333333
2003-07-01  11157.000000


        
2 Answers
  • 2021-01-02 10:21

    Here's a quick example of how to do this using pandas.ols (note that pandas.ols was removed in pandas 0.20, so this only runs on older versions):

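    import matplotlib.pyplot as plt
    import numpy as np  # needed for np.arange and np.random below
    import pandas as pd
    
    # synthetic data: a line with slope 2 and intercept 10, plus integer noise
    x = pd.Series(np.arange(50))
    y = pd.Series(10 + (2 * x + np.random.randint(-5, 5, 50)))
    regression = pd.ols(y=y, x=x)  # removed in pandas 0.20
    regression.summary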
    
    -------------------------Summary of Regression Analysis-------------------------
    
    Formula: Y ~ <x> + <intercept>
    
    Number of Observations:         50
    Number of Degrees of Freedom:   2
    
    R-squared:         0.9913
    Adj R-squared:     0.9911
    
    Rmse:              2.7625
    
    F-stat (1, 48):  5465.1446, p-value:     0.0000
    
    Degrees of Freedom: model 1, resid 48
    
    -----------------------Summary of Estimated Coefficients------------------------
          Variable       Coef    Std Err     t-stat    p-value    CI 2.5%   CI 97.5%
    --------------------------------------------------------------------------------
                 x     2.0013     0.0271      73.93     0.0000     1.9483     2.0544
         intercept     9.5271     0.7698      12.38     0.0000     8.0183    11.0358
    ---------------------------------End of Summary---------------------------------
    
    trend = regression.predict(beta=regression.beta, x=x[20:]) # slicing to only use last 30 points
    data = pd.DataFrame(index=x, data={'y': y, 'trend': trend})
    data.plot() # add kwargs for title and other layout/design aspects
    plt.show() # or plt.gcf().savefig(path)
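
    Because pd.ols is no longer available in current pandas, the same idea can be sketched with numpy.polyfit instead. This is a swapped-in technique, not the method used above: it gives only the slope and intercept, with no regression summary. The fixed seed and the helper name plot_trend are assumptions made for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # fixed seed (an assumption) so the sketch is reproducible
x = pd.Series(np.arange(50))
y = 10 + 2 * x + rng.integers(-5, 5, size=50)

# Fit a straight line (degree-1 polynomial) to the last 30 points only.
slope, intercept = np.polyfit(x[20:].to_numpy(), y[20:].to_numpy(), deg=1)

# Evaluate the line on the fitted range, then reindex so the first 20 rows are NaN.
trend = (intercept + slope * x[20:]).reindex(x.index)
data = pd.DataFrame({'y': y, 'trend': trend})


def plot_trend(frame):
    """Hypothetical helper: plot the data and its trend line (needs matplotlib)."""
    import matplotlib.pyplot as plt
    ax = frame.plot()
    plt.show()
    return ax
```

    Calling plot_trend(data) draws both columns; the NaN rows simply leave the trend line undrawn over the first 20 points.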
    

  • 2021-01-02 10:32

    In general you should create your matplotlib Figure and Axes objects ahead of time, and explicitly plot the dataframe on those axes:

    from matplotlib import pyplot
    import pandas
    import statsmodels.api as sm
    
    df = pandas.read_csv(...)
    
    fig, ax = pyplot.subplots()
    df.plot(x='xcol', y='ycol', ax=ax)
    

    Then you still have that axes object around to use directly to plot your line:

    model = sm.formula.ols(formula='ycol ~ xcol', data=df)
    res = model.fit()
    df.assign(fit=res.fittedvalues).plot(x='xcol', y='fit', ax=ax)
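
    The data in the question is indexed by date rather than by a numeric column, so the x values have to be converted to numbers before fitting. The following is a minimal sketch under that assumption, using only the three rows shown in the question and numpy.polyfit in place of statsmodels; the names days, fit and plot_with_trend are illustrative, not part of any library.

```python
import numpy as np
import pandas as pd

# The three rows shown in the question.
df = pd.DataFrame(
    {'emplvl': [10955.0, 11090.333333, 11157.0]},
    index=pd.to_datetime(['2003-01-01', '2003-04-01', '2003-07-01']),
)
df.index.name = 'date'

# Regress on days elapsed since the first observation instead of raw dates.
days = np.asarray((df.index - df.index[0]).days, dtype=float)
slope, intercept = np.polyfit(days, df['emplvl'].to_numpy(), deg=1)
df['fit'] = intercept + slope * days


def plot_with_trend(frame):
    """Hypothetical helper: draw the series and its fitted line on one shared axes."""
    from matplotlib import pyplot
    fig, ax = pyplot.subplots()
    frame['emplvl'].plot(ax=ax)
    frame['fit'].plot(ax=ax)
    return fig, ax
```

    Here slope is expressed in emplvl units per day; plot_with_trend(df) reuses a single axes object for both lines, as recommended above.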
    