Question
For the given imbalanced data, I have created separate pipelines for standardization and one-hot encoding:
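from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
numeric_transformer = Pipeline(steps=[('scaler', StandardScaler())])
categorical_transformer = Pipeline(steps=[('ohe', OneHotCategoricalEncoder())])  # each step is a (name, transformer) tuple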
After that, I combined both pipelines in a single ColumnTransformer:
from sklearn.compose import ColumnTransformer
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])
The final pipeline is as follows:
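from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as pl1  # assumed: pl1 is imbalanced-learn's Pipeline, which accepts samplers like SMOTE
from sklearn.ensemble import RandomForestClassifier
smt = SMOTE(random_state=42)
rf = pl1([('preprocessor', preprocessor), ('smote', smt),
          ('classifier', RandomForestClassifier())])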
I am fitting the pipeline on imbalanced data, so I have included the SMOTE step along with the preprocessing and the classifier. Because the data is imbalanced, I want to check the recall score.
Is the code below the correct way to do this? I am getting a recall of around 0.98, which makes me suspect that the model is overfitting. Am I making a mistake somewhere?
from sklearn.model_selection import cross_val_score
scores = cross_val_score(rf, X, y, cv=5, scoring="recall")
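For reference, here is a self-contained sketch of the preprocessing part of the question's setup that runs with scikit-learn and pandas alone. It is illustrative, not the asker's exact code: the toy DataFrame and its columns (`age`, `income`, `city`) are hypothetical, scikit-learn's own `OneHotEncoder` stands in for `OneHotCategoricalEncoder`, and the SMOTE step is left out because it needs imbalanced-learn's `Pipeline`. (In that `Pipeline`, samplers such as SMOTE run only during `fit`, so inside `cross_val_score` the oversampling touches only the training folds.)

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy data: two numeric columns, one categorical column, 90/10 target
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({
    "age": rng.normal(40, 10, n),
    "income": rng.normal(50_000, 15_000, n),
    "city": rng.choice(["A", "B", "C"], n),
})
y = np.array([0] * 180 + [1] * 20)

numeric_features = ["age", "income"]
categorical_features = ["city"]

# Every Pipeline step is a (name, estimator) tuple
numeric_transformer = Pipeline(steps=[("scaler", StandardScaler())])
categorical_transformer = Pipeline(
    steps=[("ohe", OneHotEncoder(handle_unknown="ignore"))]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_features),
        ("cat", categorical_transformer, categorical_features),
    ]
)

# 2 scaled numeric columns + 3 one-hot columns (one per city)
Xt = preprocessor.fit_transform(X)
print(Xt.shape)  # (200, 5)

# Preprocessing + classifier; the question's SMOTE step is omitted here
clf = Pipeline([
    ("preprocessor", preprocessor),
    ("classifier", RandomForestClassifier(random_state=0)),
])
clf.fit(X, y)
```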
Answer 1:
The important concern in imbalanced settings is to ensure that enough members of the minority class are present in each CV fold; thus, it would seem advisable to enforce that using StratifiedKFold, i.e.:
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5)
scores = cross_val_score(rf, X, y, cv=skf, scoring="recall")
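To see why stratification matters, the sketch below counts minority-class samples in each test fold for plain `KFold` versus `StratifiedKFold`. The labels are hypothetical (90 negatives followed by 10 positives, deliberately sorted), and neither splitter shuffles.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical imbalanced labels, sorted so that all positives come last
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # features do not affect how the folds are built

def minority_per_test_fold(cv):
    """Count class-1 samples in each test fold produced by `cv`."""
    return [int(y[test].sum()) for _, test in cv.split(X, y)]

plain = minority_per_test_fold(KFold(n_splits=5))
stratified = minority_per_test_fold(StratifiedKFold(n_splits=5))
print("KFold:          ", plain)       # [0, 0, 0, 0, 10]
print("StratifiedKFold:", stratified)  # [2, 2, 2, 2, 2]
```

With unshuffled `KFold`, four of the five test folds contain no positives at all, so recall is undefined on them; `StratifiedKFold` keeps the 10% positive rate in every fold.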
Nevertheless, it turns out that even when you call cross_val_score as you do (i.e. simply with cv=5), scikit-learn takes care of this and does use stratified CV; from the docs:
cv : int, cross-validation generator or an iterable, default=None
    None, to use the default 5-fold cross validation,
    int, to specify the number of folds in a (Stratified)KFold.
    For int/None inputs, if the estimator is a classifier and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used.
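The rule quoted above can be checked directly with `sklearn.model_selection.check_cv`, which `cross_val_score` uses internally to turn an integer `cv` into a splitter, passing `classifier=is_classifier(estimator)`. The sketch also shows that a `Pipeline` whose last step is a classifier counts as a classifier; imbalanced-learn's `Pipeline` subclasses scikit-learn's, so the asker's `rf` should behave the same way. The label arrays are hypothetical.

```python
import numpy as np
from sklearn.base import is_classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import check_cv
from sklearn.pipeline import Pipeline

y_binary = np.array([0] * 90 + [1] * 10)     # hypothetical imbalanced labels
y_continuous = np.linspace(0.0, 1.0, 100)    # hypothetical regression target

# A pipeline ending in a classifier is treated as a classifier
pipe = Pipeline([("classifier", RandomForestClassifier())])
print(is_classifier(pipe))  # True

# Integer cv resolved the same way cross_val_score resolves it
cv_clf = check_cv(5, y_binary, classifier=is_classifier(pipe))
cv_reg = check_cv(5, y_continuous, classifier=False)

print(type(cv_clf).__name__, cv_clf.get_n_splits())  # StratifiedKFold 5
print(type(cv_reg).__name__, cv_reg.get_n_splits())  # KFold 5
```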
So, using your code as is:
scores = cross_val_score(rf, X, y, cv=5, scoring="recall")
is perfectly fine.
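As a quick sanity check, the sketch below scores the same pipeline with `cv=5` and with an explicit `StratifiedKFold(n_splits=5)` and confirms the recall scores are identical. Assumptions: synthetic data from `make_classification` with a roughly 90/10 class split, and a scikit-learn-only stand-in pipeline (scaler plus random forest, no SMOTE, so imbalanced-learn is not required). Fixed `random_state` values make each fit deterministic.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical imbalanced data: roughly 90% class 0, 10% class 1
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

# scikit-learn-only stand-in for the question's pipeline (no SMOTE step)
rf = Pipeline([
    ("scaler", StandardScaler()),
    ("classifier", RandomForestClassifier(random_state=0)),
])

scores_int = cross_val_score(rf, X, y, cv=5, scoring="recall")
scores_skf = cross_val_score(
    rf, X, y, cv=StratifiedKFold(n_splits=5), scoring="recall"
)

# Same unshuffled stratified splits and fixed seeds, so the scores match
print(np.array_equal(scores_int, scores_skf))  # True
```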
Source: https://stackoverflow.com/questions/62308095/correct-way-to-do-cross-validation-in-a-pipeline-with-imbalanced-data