How to implement SMOTE in cross validation and GridSearchCV

Posted by 左心房为你撑大大i on 2019-12-03 03:18:39

You need to look at the Pipeline object. imbalanced-learn has a Pipeline which extends the scikit-learn Pipeline so that, in addition to scikit-learn's fit_predict(), fit_transform() and predict() methods, it also handles the samplers' fit_sample() and sample() methods (fit_sample() was renamed fit_resample() in later imbalanced-learn releases).

Have a look at this example here:

For your code, you would want to do this:

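from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline, Pipeline
from sklearn.ensemble import RandomForestClassifier

# sm is the SMOTE instance from your own code; a default one is assumed here
sm = SMOTE(random_state=1)
smote_enn = SMOTEENN(smote=sm)
clf_rf = RandomForestClassifier(n_estimators=25, random_state=1)

# Either let make_pipeline name the steps automatically ...
pipeline = make_pipeline(smote_enn, clf_rf)
# ... or name the steps explicitly:
pipeline = Pipeline([('smote_enn', smote_enn),
                     ('clf_rf', clf_rf)])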

Then you can pass this pipeline object to GridSearchCV, RandomizedSearchCV or any other cross-validation tool in scikit-learn, just like a regular estimator.

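from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold

# n_splits and param_dist are defined in your own code
kf = StratifiedKFold(n_splits=n_splits)
random_search = RandomizedSearchCV(pipeline, param_distributions=param_dist,
                                   n_iter=1000,
                                   cv=kf)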

This looks like it would fit the bill: http://contrib.scikit-learn.org/imbalanced-learn/stable/generated/imblearn.over_sampling.SMOTE.html

You'll want to create your own transformer (http://scikit-learn.org/stable/modules/generated/sklearn.base.TransformerMixin.html) that, upon calling fit, returns a balanced data set (presumably the one obtained from StratifiedKFold), but upon calling predict, which is what will happen for the test data, calls into SMOTE.
