Using statsmodels estimators with scikit-learn cross-validation: is it possible?

無奈伤痛 2021-01-31 19:05

I posted this question to the Cross Validated forum and later realized it might find a more appropriate audience on Stack Overflow.

I am looking for a way I can use

4 Answers
    谎友^ (OP)
    2021-01-31 19:17

    For reference purposes, if you use the statsmodels formula API and/or the fit_regularized method, you can modify @David Dale's wrapper class as follows.

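    import numpy as np
    from sklearn.base import BaseEstimator, RegressorMixin
    from statsmodels.formula.api import glm as glm_sm
    
    # This is an example wrapper for a statsmodels GLM fitted through the formula API.
    # X must be a pandas DataFrame whose column names match the formula, and the
    # response must be called 'y' in the formula, e.g. "y ~ x1 + x2".
    # scikit-learn recommends listing the mixin before BaseEstimator.
    class SMWrapper(RegressorMixin, BaseEstimator):
        def __init__(self, family, formula, alpha, L1_wt):
            # Only store hyperparameters here so that sklearn.base.clone() works
            self.family = family
            self.formula = formula
            self.alpha = alpha
            self.L1_wt = L1_wt
    
        def fit(self, X, y):
            # Attach y by position: pd.concat aligns on the index, which yields
            # NaNs when y is a NumPy array and X is a cross-validation subset
            data = X.copy()
            data['y'] = np.asarray(y)
            self.model_ = glm_sm(self.formula, data, family=self.family)
            self.result_ = self.model_.fit_regularized(
                alpha=self.alpha, L1_wt=self.L1_wt, refit=True)
            return self  # scikit-learn expects fit() to return the estimator
    
        def predict(self, X):
            # The formula API rebuilds the design matrix from X's columns
            return self.result_.predict(X)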
    
