Using statsmodels estimations with scikit-learn cross-validation, is it possible?

無奈伤痛 2021-01-31 19:05

I posted this question on the Cross Validated forum and later realized that it might find a more appropriate audience on Stack Overflow instead.

I am looking for a way I can use

4 Answers
  •  鱼传尺愫
    2021-01-31 19:23

    Indeed, you cannot use cross_val_score directly on statsmodels objects, because their interface differs from scikit-learn's. In statsmodels:

    • training data is passed directly into the constructor
    • a separate object contains the result of model estimation

    However, you can write a simple wrapper to make statsmodels objects look like sklearn estimators:

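    import statsmodels.api as sm
    from sklearn.base import BaseEstimator, RegressorMixin
    
    # Mixins go to the left of BaseEstimator, as scikit-learn recommends
    class SMWrapper(RegressorMixin, BaseEstimator):
        """ A universal sklearn-style wrapper for statsmodels regressors """
        def __init__(self, model_class, fit_intercept=True):
            self.model_class = model_class
            self.fit_intercept = fit_intercept
        def fit(self, X, y):
            if self.fit_intercept:
                # has_constant='add' always adds the column, even for a single-row X
                X = sm.add_constant(X, has_constant='add')
            # statsmodels takes the data in the constructor: (endog, exog)
            self.model_ = self.model_class(y, X)
            self.results_ = self.model_.fit()
            return self  # sklearn convention: fit returns the estimator
        def predict(self, X):
            if self.fit_intercept:
                X = sm.add_constant(X, has_constant='add')
            return self.results_.predict(X)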
    

    This class implements the fit and predict methods that scikit-learn expects, so it can be cross-validated or included in a pipeline. For example:

    from sklearn.datasets import make_regression
    from sklearn.model_selection import cross_val_score
    from sklearn.linear_model import LinearRegression
    
    X, y = make_regression(random_state=1, n_samples=300, noise=100)
    
    print(cross_val_score(SMWrapper(sm.OLS), X, y, scoring='r2'))
    print(cross_val_score(LinearRegression(), X, y, scoring='r2'))
    

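    The two models produce identical scores, because both are OLS models cross-validated in the same way. The three scores correspond to 3-fold cross-validation, which was the default in older scikit-learn releases; since version 0.22 the default is 5 folds, so pass cv=3 to get three scores.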

    [0.28592315 0.37367557 0.47972639]
    [0.28592315 0.37367557 0.47972639]
    
