How to generate a custom cross-validation generator in scikit-learn?

萌比男神i 2021-01-31 20:08

I have an unbalanced dataset, so I have a strategy for oversampling that I only apply when training on my data. I'd like to use scikit-learn classes like GridSearch
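
One way to apply oversampling only to the training folds, sketched under the assumption that the strategy is plain random oversampling with replacement (the actual strategy isn't stated here): a custom CV iterator only supplies index arrays, so minority-class training indices can simply be repeated while test indices are left alone. `oversample_splits` is a hypothetical helper, not part of scikit-learn; it wraps any iterable of (train_index, test_index) pairs.

```python
import numpy as np


def oversample_splits(splits, y, seed=0):
    """Hypothetical helper (random oversampling with replacement is assumed).

    For each (train_index, test_index) pair, repeat minority-class training
    indices until every class matches the majority count in that fold.
    Test indices are passed through unchanged.
    """
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    for train_idx, test_idx in splits:
        train_idx = np.asarray(train_idx)
        y_train = y[train_idx]
        classes, counts = np.unique(y_train, return_counts=True)
        target = counts.max()
        parts = []
        for cls, count in zip(classes, counts):
            cls_idx = train_idx[y_train == cls]
            # Draw (target - count) extra indices of this class; empty if already balanced
            extra = cls_idx[rng.integers(0, len(cls_idx), size=target - count)]
            parts.append(np.concatenate([cls_idx, extra]))
        yield np.concatenate(parts), np.asarray(test_idx)
```

The resulting list, e.g. `list(oversample_splits(myCViterator, y))`, could then be passed as `cv=` to GridSearchCV. scikit-learn indexes X and y with each training array, so repeated indices just duplicate rows in that fold's training set, and the untouched test folds keep the original class balance.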

4 answers
  •  一个人的身影
    2021-01-31 20:41

    A cross-validation generator returns an iterable of length n_folds, each element of which is a 2-tuple of NumPy 1-D arrays (train_index, test_index) containing the indices of the training and test sets for that cross-validation run.
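
    For comparison, scikit-learn's built-in splitters produce exactly this structure. A short sketch using KFold on a made-up toy array X; note that in the current sklearn.model_selection API the fold count is called n_splits rather than n_folds:

```python
import numpy as np
from sklearn.model_selection import KFold

# Toy feature matrix: 10 samples, 2 features (made up for illustration)
X = np.arange(20).reshape(10, 2)

# Each element of `folds` is a (train_index, test_index) pair of 1-D integer arrays
folds = list(KFold(n_splits=5).split(X))

for train_index, test_index in folds:
    print("train:", train_index, "test:", test_index)
```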

    So for 10-fold cross-validation, your custom cross-validation generator needs to contain 10 elements, each of which is a tuple of two elements:

    • An array of the indices for the training subset for that run, covering 90% of your data
    • An array of the indices for the testing subset for that run, covering 10% of the data
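
    A minimal hand-built version of such a 10-fold iterator, assuming 100 samples split into contiguous folds with no shuffling or stratification:

```python
import numpy as np

n_samples = 100
n_folds = 10
indices = np.arange(n_samples)

# Split the indices into 10 contiguous chunks; each chunk serves once as the test fold
chunks = np.array_split(indices, n_folds)
custom_cv = []
for k in range(n_folds):
    test_index = chunks[k]
    # Training indices are all the other chunks joined together
    train_index = np.concatenate([chunks[j] for j in range(n_folds) if j != k])
    custom_cv.append((train_index, test_index))

print(len(custom_cv), len(custom_cv[0][0]), len(custom_cv[0][1]))  # 10 90 10
```

    Each training array holds 90 indices and each test array 10, and the test arrays together cover every sample exactly once.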

    I was working on a similar problem, in which I created integer labels for the different folds of my data. My dataset is stored in a pandas DataFrame, myDf, which has a column cvLabel holding the cross-validation labels. I constructed the custom cross-validation generator myCViterator as follows:

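    import numpy as np
    # Assumes 'cvLabel' holds integer fold labels 0..nFolds-1
    nFolds = myDf['cvLabel'].nunique()
    labels = myDf['cvLabel'].to_numpy()
    myCViterator = []
    for i in range(nFolds):
        # Positional row indices (what scikit-learn expects), not index labels
        trainIndices = np.flatnonzero(labels != i)
        testIndices = np.flatnonzero(labels == i)
        myCViterator.append((trainIndices, testIndices))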
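
    To tie this back to the question, a sketch of passing such an iterator to GridSearchCV. The toy DataFrame, its feature columns x1/x2, the target y, and the LogisticRegression estimator are all assumptions for illustration. The DataFrame deliberately has a non-default index: scikit-learn selects DataFrame rows by position, so the iterator must hold positional indices, which np.flatnonzero provides.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Toy data (assumption): 40 rows, 2 features, fold labels 0..3 in 'cvLabel',
# and a non-default index (100..139) to show why positions matter
rng = np.random.default_rng(0)
myDf = pd.DataFrame(
    {
        "x1": rng.normal(size=40),
        "x2": rng.normal(size=40),
        "cvLabel": np.tile(np.arange(4), 10),
    },
    index=np.arange(100, 140),
)
myDf["y"] = (myDf["x1"] + 0.5 * rng.normal(size=40) > 0).astype(int)

nFolds = 4
labels = myDf["cvLabel"].to_numpy()
myCViterator = [
    (np.flatnonzero(labels != i), np.flatnonzero(labels == i))
    for i in range(nFolds)
]

search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    cv=myCViterator,  # an iterable of (train_index, test_index) pairs
)
search.fit(myDf[["x1", "x2"]], myDf["y"])
print(search.best_params_, search.best_score_)
```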
    
