Question
What am I trying to do?
I am trying to use StratifiedKFold() in GridSearchCV().
What confuses me?
With K-Fold cross-validation, we just pass the number of folds as cv to GridSearchCV(), like this:
grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=5, scoring='f1', return_train_score=True, n_jobs=2)
So when I need StratifiedKFold(), I would expect the procedure to remain the same, i.e. set only the number of splits and pass StratifiedKFold(n_splits=5) as cv:
grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=StratifiedKFold(n_splits=5), scoring='f1', return_train_score=True, n_jobs=2)
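As a rough illustration of what stratification changes compared with plain KFold, the sketch below assumes a hypothetical imbalanced target (90 negatives followed by 10 positives) and counts how many positives land in each test fold under KFold(n_splits=5) and StratifiedKFold(n_splits=5), both without shuffling.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical imbalanced binary target: 90 negatives followed by 10 positives.
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # feature values do not affect how the folds are built


def positives_per_test_fold(cv):
    """Count the positive labels in each test fold produced by cv.split(X, y)."""
    return [int(y[test_idx].sum()) for _, test_idx in cv.split(X, y)]


kfold_counts = positives_per_test_fold(KFold(n_splits=5))
stratified_counts = positives_per_test_fold(StratifiedKFold(n_splits=5))

print("KFold:          ", kfold_counts)       # contiguous folds: all positives in one fold
print("StratifiedKFold:", stratified_counts)  # every fold keeps the 10% positive rate
```

Either object is passed to GridSearchCV the same way; only the way the indices are chosen differs.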
But this answer says that whatever cross-validation strategy is used, all that is needed is to provide the generator returned by the split function, as suggested:
kfolds = StratifiedKFold(5)
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain, ytrain))
clf.fit(xtrain, ytrain)
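A runnable version of the quoted snippet might look like the sketch below. The quoted answer does not define qwk, estimator, parameters, xtrain or ytrain, so all of them are assumptions here: qwk is taken to mean a quadratic weighted kappa scorer built with make_scorer(cohen_kappa_score, weights="quadratic"), and the estimator, grid and data are hypothetical stand-ins.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import cohen_kappa_score, make_scorer
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Hypothetical 3-class data standing in for xtrain / ytrain.
xtrain, ytrain = make_classification(
    n_samples=300, n_features=10, n_informative=5, n_classes=3, random_state=0
)

# Assumption: `qwk` means quadratic weighted kappa, built from a standard metric.
qwk = make_scorer(cohen_kappa_score, weights="quadratic")

estimator = LogisticRegression(max_iter=1000)  # hypothetical estimator
parameters = {"C": [0.1, 1.0, 10.0]}           # hypothetical grid

kfolds = StratifiedKFold(5)
# split() returns a generator of (train_indices, test_indices) pairs for this data.
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain, ytrain))
clf.fit(xtrain, ytrain)
print(clf.best_params_, clf.best_score_)
```

Because the generator returned by split() is consumed during the first fit, a second clf.fit(...) on the same GridSearchCV object would find it already exhausted; passing the StratifiedKFold object itself (cv=kfolds) avoids that, since its folds are regenerated on every fit.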
Moreover, one of the answers to this question also suggests doing this, i.e. calling the split function, StratifiedKFold(n_splits=5).split(xtrain,ytrain), when passing cv to GridSearchCV(). However, I found that calling split() and not calling it give me the same F1 score.
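The observation above can be reproduced with a minimal sketch on hypothetical synthetic data (make_classification) and a small, hypothetical grid. rdm_forest_clf is given a fixed random_state here (an assumption, not from the question) so that both searches fit identical forests and any score difference could only come from the folds.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Hypothetical binary data and grid standing in for the question's own.
xtrain, ytrain = make_classification(n_samples=200, n_features=8, random_state=0)
param_grid = {"max_depth": [2, 4]}
rdm_forest_clf = RandomForestClassifier(n_estimators=20, random_state=0)


def mean_f1(cv):
    """Run the grid search with the given cv and return the mean test F1 per candidate."""
    search = GridSearchCV(rdm_forest_clf, param_grid, cv=cv, scoring="f1")
    search.fit(xtrain, ytrain)
    return search.cv_results_["mean_test_score"]


with_object = mean_f1(StratifiedKFold(n_splits=5))
with_split = mean_f1(StratifiedKFold(n_splits=5).split(xtrain, ytrain))
print(with_object, with_split)  # same folds + same seeded forest -> same scores
```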
So my questions are:
1. I do not understand why we need to call split() with Stratified K-Fold when we do not need to do anything like that with plain K-Fold CV.
2. If split() is called, how does GridSearchCV() work, given that split() returns the indices of the training and test sets? In other words, how does GridSearchCV() use those indices?
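For the question of how the indices are used, the sketch below spells out roughly what happens for one parameter setting: for each (train_idx, test_idx) pair, a fresh copy of the estimator is fitted on the training rows and scored on the test rows, and the fold scores are averaged. GridSearchCV repeats this for every candidate in param_grid, keeps the candidate with the best mean score and, with the default refit=True, refits it on the full data. This is a simplified model of the behaviour, not the library's actual code (which also handles parallelism, multiple metrics and failures); the data and forest settings are hypothetical, and cross_val_score is used only as a cross-check.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Hypothetical data and a single hypothetical parameter setting.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
model = RandomForestClassifier(n_estimators=20, max_depth=3, random_state=0)

fold_scores = []
# Each item yielded by split() is a pair of integer index arrays into X and y.
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
    fold_model = clone(model)                   # fresh, unfitted copy per fold
    fold_model.fit(X[train_idx], y[train_idx])  # train on the training rows only
    preds = fold_model.predict(X[test_idx])     # predict the held-out rows
    fold_scores.append(f1_score(y[test_idx], preds))

manual_mean = np.mean(fold_scores)
# cross_val_score performs the same per-fold fit/score for one parameter setting.
library_mean = cross_val_score(
    model, X, y, cv=StratifiedKFold(n_splits=5), scoring="f1"
).mean()
print(manual_mean, library_mean)
```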
Answer 1:
Basically, GridSearchCV is clever and accepts several kinds of value for the cv parameter: a number, an iterable of split indices, or an object with a split method. You can look at the code here (the logic lives in sklearn.model_selection.check_cv), copied below.
cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
    if (classifier and (y is not None) and
            (type_of_target(y) in ('binary', 'multiclass'))):
        return StratifiedKFold(cv)
    else:
        return KFold(cv)
if not hasattr(cv, 'split') or isinstance(cv, str):
    if not isinstance(cv, Iterable) or isinstance(cv, str):
        raise ValueError("Expected cv as an integer, cross-validation "
                         "object (from sklearn.model_selection) "
                         "or an iterable. Got %s." % cv)
    return _CVIterableWrapper(cv)
return cv  # New style cv objects are passed without any modification
Basically, if you pass nothing (cv=None) or an integer, it uses KFold with that many splits (5 by default). It is also clever enough to switch to StratifiedKFold automatically when the estimator is a classifier and the target is binary or multiclass.
If you pass an object with a split method, it uses that object unchanged. If you pass neither, but an iterable, it assumes the iterable holds the (train, test) split indices and wraps it up for you (in _CVIterableWrapper).
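The branches of the quoted code can be exercised directly, since check_cv is importable from sklearn.model_selection. The sketch below uses hypothetical toy targets; the class returned for a plain iterable (_CVIterableWrapper) is private, so only its behaviour (it has split() and get_n_splits()) should be relied on.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

y_binary = np.array([0, 1] * 10)          # hypothetical classification target
y_continuous = np.linspace(0.0, 1.0, 20)  # hypothetical regression target

# Integer cv, classifier, binary target: resolved to StratifiedKFold.
int_classifier = check_cv(5, y_binary, classifier=True)
# Integer cv with a continuous target, or with a non-classifier: plain KFold.
int_regression = check_cv(5, y_continuous, classifier=True)
int_not_classifier = check_cv(5, y_binary, classifier=False)
# cv=None falls back to 5 folds.
default_cv = check_cv(None, y_binary, classifier=True)
# An object with a split() method comes back unchanged.
skf = StratifiedKFold(n_splits=3)
passed_through = check_cv(skf, y_binary, classifier=True)
# A plain list of (train, test) index pairs is wrapped in an object with split().
pairs = [(np.arange(10), np.arange(10, 20)), (np.arange(10, 20), np.arange(10))]
wrapped = check_cv(pairs)

for name, cv in [("int_classifier", int_classifier),
                 ("int_regression", int_regression),
                 ("int_not_classifier", int_not_classifier),
                 ("default_cv", default_cv)]:
    print(name, type(cv).__name__, cv.get_n_splits())
print("passed through unchanged:", passed_through is skf)
print("wrapped iterable:", type(wrapped).__name__, wrapped.get_n_splits())
```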
So in your case, assuming it is a classification problem with a binary or multiclass target, all of the options below give exactly the same results and splits; it does not matter which one you use:
cv=5
cv=StratifiedKFold(5)
cv=StratifiedKFold(5).split(xtrain,ytrain)
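The claim that the three options are interchangeable can be checked on the folds themselves: the sketch below resolves each option through check_cv, as GridSearchCV does for a classifier, and compares the resulting train/test indices. The data is hypothetical synthetic binary data.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, check_cv

xtrain, ytrain = make_classification(n_samples=100, random_state=0)  # hypothetical data

options = {
    "cv=5": 5,
    "cv=StratifiedKFold(5)": StratifiedKFold(5),
    "cv=StratifiedKFold(5).split(...)": StratifiedKFold(5).split(xtrain, ytrain),
}

# Resolve each option the way GridSearchCV does for a classifier, then list its folds.
folds = {
    name: list(check_cv(cv, ytrain, classifier=True).split(xtrain, ytrain))
    for name, cv in options.items()
}
reference = folds["cv=5"]


def same_folds(pairs, ref):
    """True if both fold lists hold the same train/test indices in the same order."""
    return len(pairs) == len(ref) and all(
        np.array_equal(tr, ref_tr) and np.array_equal(te, ref_te)
        for (tr, te), (ref_tr, ref_te) in zip(pairs, ref)
    )


for name, pairs in folds.items():
    print(f"{name}: {len(pairs)} folds, identical to cv=5: {same_folds(pairs, reference)}")
```

One practical difference remains: the split(...) form fixes the indices for xtrain and ytrain when it is created, so it only lines up if fit is then called on exactly that data, whereas cv=5 and cv=StratifiedKFold(5) compute the folds from whatever data is passed to fit.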
Source: https://stackoverflow.com/questions/62174112/why-we-should-call-split-function-during-passing-stratifiedkfold-as-a-parame