sklearn KFold: access a single fold instead of using a for loop

-上瘾入骨i 2021-02-04 07:18

After using cross_validation.KFold(n, n_folds=folds), I would like to access the training and testing indices of a single fold instead of going through all the folds.


2 Answers
  • 2021-02-04 07:59

    You are on the right track. All you need to do now is:

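    from sklearn import cross_validation  # old API, sklearn < 0.20 only

    kf = cross_validation.KFold(4, n_folds=2)
    mylist = list(kf)       # compute and store every (train, test) split
    train, test = mylist[0] # indices of the first fold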
    

    kf is actually a lazy iterable: iterating over it runs a generator that doesn't compute a train/test split until it is needed. This saves memory, since you are not storing splits you don't need. Calling list() on the KFold object forces it to compute and store all of them.
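In the current API (sklearn >= 0.20) the lazy part is kf.split(X), which yields one (train_index, test_index) pair at a time. A minimal sketch, assuming a toy 6-sample array, of pulling out only the k-th fold with itertools.islice so that no list of all folds is stored; fold_k is a hypothetical helper name, not part of scikit-learn:

```python
from itertools import islice

import numpy as np
from sklearn.model_selection import KFold

X = np.arange(12).reshape(6, 2)  # 6 toy samples with 2 features each
kf = KFold(n_splits=3)


def fold_k(cv, X, k):
    """Hypothetical helper: return (train_index, test_index) for fold k (0-based).

    cv.split(X) yields splits lazily; islice consumes and discards the
    first k splits without storing them, and next() stops after one more.
    """
    return next(islice(cv.split(X), k, None))


train_index, test_index = fold_k(kf, X, 1)
print(train_index, test_index)  # [0 1 4 5] [2 3]

# By contrast, list() computes and stores every split up front
all_folds = list(kf.split(X))
print(len(all_folds))  # 3
```

The islice approach still generates the earlier splits, it just never keeps them; list() is simpler when you will revisit several folds.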

    Here are two great SO questions that explain what generators are: one and two


    Edit Nov 2018

    The API changed in sklearn 0.20, which removed the old cross_validation module in favor of model_selection. An updated example (for Python 3.6):

    from sklearn.model_selection import KFold
    import numpy as np
    
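    kf = KFold(n_splits=2)
    
    X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    
    # split() yields (train_index, test_index) pairs lazily;
    # next() pulls out only the first fold
    train_index, test_index = next(kf.split(X))
    
    In [12]: train_index
    Out[12]: array([2, 3])
    
    In [13]: test_index
    Out[13]: array([0, 1])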
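The arrays returned by split() are positional indices, not the data itself. A minimal sketch of turning them into the fold's actual rows, assuming a NumPy feature matrix X and a label vector y (the y values are made up for illustration):

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 0, 1])  # illustrative labels, not from the question

kf = KFold(n_splits=2)
train_index, test_index = next(kf.split(X))  # first fold only

# Index the arrays with the fold's indices to select its rows
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]

print(y_train, y_test)  # [0 1] [0 0]
```

For a pandas DataFrame, positional selection with X.iloc[train_index] does the same job.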
    
  • 2021-02-04 08:21
    # Save every fold's split in separate lists, then access any fold by index [i]
    from sklearn.model_selection import KFold
    import numpy as np
    import pandas as pd
    
    kf = KFold(n_splits=4)
    
    X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    
    Y = np.array([0, 0, 0, 1])
    Y = Y.reshape(4, 1)
    
    X = pd.DataFrame(X)
    Y = pd.DataFrame(Y)
    
    
    X_train_base=[]
    X_test_base=[]
    Y_train_base=[]
    Y_test_base=[]
    
    for train_index, test_index in kf.split(X):
    
        X_train, X_test = X.iloc[train_index,:], X.iloc[test_index,:]
        Y_train, Y_test = Y.iloc[train_index,:], Y.iloc[test_index,:]
        X_train_base.append(X_train)
        X_test_base.append(X_test)
        Y_train_base.append(Y_train)
        Y_test_base.append(Y_test)
    
    print(X_train_base[0])
    print(Y_train_base[0])
    print(X_train_base[1])
    print(Y_train_base[1])
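Storing four lists of DataFrame slices keeps a copy of every fold in memory. A lighter variant of the same idea, sketched under the assumption that keeping only the index pairs is enough: materialize list(kf.split(X)) once and slice the DataFrames on demand. get_fold is a hypothetical helper name:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

X = pd.DataFrame(np.array([[1, 2], [3, 4], [1, 2], [3, 4]]))
Y = pd.DataFrame(np.array([0, 0, 0, 1]).reshape(4, 1))

kf = KFold(n_splits=4)

# Keep only the (train_index, test_index) pairs, not copies of the data
folds = list(kf.split(X))


def get_fold(i):
    """Hypothetical helper: return X_train, X_test, Y_train, Y_test for fold i."""
    train_index, test_index = folds[i]
    return (X.iloc[train_index], X.iloc[test_index],
            Y.iloc[train_index], Y.iloc[test_index])


X_train, X_test, Y_train, Y_test = get_fold(1)
print(X_test)
print(Y_train)
```

The index arrays are small compared with the data, so this scales better when X is large.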
    