Stratified Train/Test-split in scikit-learn


I need to split my data into a training set (75%) and test set (25%). I currently do that with the code below:

X, Xt, userInfo, userInfo_train = sklearn.cros         


        
7 answers
  • 2020-11-27 03:13

    In addition to the accepted answer by @Andreas Mueller, I just want to add that, as @tangy mentioned above:

    StratifiedShuffleSplit most closely resembles train_test_split(stratify=y), with the added features that:

    1. it stratifies by default
    2. by specifying n_splits, it repeatedly splits the data
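    A minimal sketch of the second point, on made-up data (a hypothetical y with 80 samples of class 0 and 20 of class 1): each of the n_splits stratified splits keeps the 80/20 class ratio in its 25% test set, while the test indices differ from split to split.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical toy data: 100 samples, imbalanced labels (80 of class 0, 20 of class 1)
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 80 + [1] * 20)

# n_splits=3 draws three independent stratified 75/25 splits
sss = StratifiedShuffleSplit(n_splits=3, test_size=0.25, random_state=0)

test_sets = []
for train_index, test_index in sss.split(X, y):
    # Each test set holds 25 samples: 20 of class 0 and 5 of class 1
    print(len(test_index), np.bincount(y[test_index]))
    test_sets.append(set(test_index))
```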
  • 2020-11-27 03:18

    Updating @tangy's answer from above to the current version of scikit-learn, 0.23.2 (see the StratifiedShuffleSplit documentation):

    from sklearn.model_selection import StratifiedShuffleSplit
    
    n_splits = 1  # We only want a single split in this case
    sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=0)
    
    for train_index, test_index in sss.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
    
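    Because n_splits=1, the loop body runs only once. A sketch of an equivalent that pulls the single split out with next() instead, assuming for illustration that X and y are a pandas DataFrame and Series (the toy df below is made up). In that case positional indices must go through .iloc, since plain X[train_index] would try to select columns.

```python
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical toy data: one feature column and an imbalanced label (32 vs 8)
df = pd.DataFrame({"feature": range(40), "label": [0] * 32 + [1] * 8})
X, y = df[["feature"]], df["label"]

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
# split() is a generator; next() takes its only (train, test) pair
train_index, test_index = next(sss.split(X, y))

# .iloc selects rows by position, which is what split() returns
X_train, X_test = X.iloc[train_index], X.iloc[test_index]
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
```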
  • 2020-11-27 03:19

    You can simply do it with the train_test_split() function available in scikit-learn:

    from sklearn.model_selection import train_test_split 
    train, test = train_test_split(X, test_size=0.25, stratify=X['YOUR_COLUMN_LABEL']) 
    

    I have also prepared a short GitHub Gist that shows how the stratify option works:

    https://gist.github.com/SHi-ON/63839f3a3647051a180cb03af0f7d0d9
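    A small sketch in the same spirit (not the Gist's actual contents), using a made-up DataFrame whose label column stands in for 'YOUR_COLUMN_LABEL': with stratify, the 25-row test set reproduces the 80/20 class ratio exactly; without it, the ratio depends on the random draw.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical DataFrame: 80 rows of class "a", 20 rows of class "b"
X = pd.DataFrame({"feature": range(100), "label": ["a"] * 80 + ["b"] * 20})

# Stratified on the label column: the test set gets 20 "a" rows and 5 "b" rows
train, test = train_test_split(X, test_size=0.25, stratify=X["label"], random_state=0)
print(test["label"].value_counts())

# Without stratify, the number of "b" rows in the test set varies with the seed
train_u, test_u = train_test_split(X, test_size=0.25, random_state=0)
print(test_u["label"].value_counts())
```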

  • 2020-11-27 03:20

    [update for 0.17]

    See the docs of sklearn.model_selection.train_test_split:

    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        stratify=y, 
                                                        test_size=0.25)
    

    [/update for 0.17]

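    There is a pull request here. But you can simply do train, test = next(iter(StratifiedKFold(...))) and use the resulting train and test indices. (That call uses the pre-0.18 API, in which a StratifiedKFold object was itself iterable; in current versions the equivalent is next(StratifiedKFold(...).split(X, y)).)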
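    A sketch of that shortcut with the current API, on hypothetical NumPy data: n_splits=4 makes the first fold's test part one quarter of the data, i.e. the 75/25 split asked for in the question. shuffle=True is added here because StratifiedKFold does not shuffle by default.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical data: 100 samples, 80 of class 0 and 20 of class 1
X = np.random.RandomState(0).rand(100, 3)
y = np.array([0] * 80 + [1] * 20)

# 4 folds -> each test fold is 1/4 of the data, with class proportions preserved
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
train_index, test_index = next(skf.split(X, y))

X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
```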

  • 2020-11-27 03:23

    TL;DR: Use StratifiedShuffleSplit with test_size=0.25.

    Scikit-learn provides two classes for stratified splitting:

    1. StratifiedKFold: this class is useful as a direct k-fold cross-validation operator, in that it sets up n_splits (n_folds in older versions) training/testing sets such that the class proportions are preserved in both.

    Here's some code (based on the documentation):

    >>> from sklearn.model_selection import StratifiedKFold
    >>> # Before 0.18 this was cross_validation.StratifiedKFold(y, n_folds=2)
    >>> skf = StratifiedKFold(n_splits=2)  # 2-fold cross-validation
    >>> skf.get_n_splits(X, y)
    2
    >>> for train_index, test_index in skf.split(X, y):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...     # fit and predict with X_train/X_test; use accuracy metrics to check validation performance
    
    2. StratifiedShuffleSplit: this class creates randomized training/testing splits with equally balanced (stratified) classes. With n_splits=1 (n_iter=1 in older versions) it yields a single split, which is essentially what you want. You can set test_size here just as in train_test_split.

    Code:

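    >>> from sklearn.model_selection import StratifiedShuffleSplit
    >>> # Before 0.18 this was StratifiedShuffleSplit(y, n_iter=1, test_size=0.25, random_state=0)
    >>> sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
    >>> sss.get_n_splits(X, y)
    1
    >>> for train_index, test_index in sss.split(X, y):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    >>> # fit and predict with your classifier using the above X/y train/test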
    
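    One practical difference between the two classes, sketched on made-up data: the test folds of StratifiedKFold partition the dataset (every sample is tested exactly once), whereas the test sets that StratifiedShuffleSplit draws on repeated iterations are independent random draws and may overlap.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

# Hypothetical data: 100 samples, 80 of class 0 and 20 of class 1
X = np.zeros((100, 1))
y = np.array([0] * 80 + [1] * 20)

# K-fold: the 4 test folds together cover each index exactly once
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
kfold_tests = np.concatenate([test for _, test in skf.split(X, y)])
print(np.array_equal(np.sort(kfold_tests), np.arange(100)))

# Shuffle-split: 4 independent 25% test draws; some indices repeat, others never appear
sss = StratifiedShuffleSplit(n_splits=4, test_size=0.25, random_state=0)
shuffle_tests = np.concatenate([test for _, test in sss.split(X, y)])
print(len(shuffle_tests), len(np.unique(shuffle_tests)))
```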
  • 2020-11-27 03:29
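
    To get a stratified train/validation/test split, call train_test_split twice:

    from sklearn.model_selection import train_test_split

    # train_size is 1 - tst_size - vld_size
    tst_size = 0.15
    vld_size = 0.15
    target = 'y'  # name of the label column in df

    # Row counts relative to the full data, so test is 15% of all rows, not of the remaining 85%
    n_valid, n_test = round(vld_size * len(df)), round(tst_size * len(df))

    # First carve off the validation set, stratified on the label
    X_train_test, X_valid, y_train_test, y_valid = train_test_split(
        df.drop(columns=target), df[target], test_size=n_valid,
        stratify=df[target], random_state=13903)

    # Then split the remainder into train and test, again stratified
    X_train, X_test, y_train, y_test = train_test_split(
        X_train_test, y_train_test, test_size=n_test,
        stratify=y_train_test, random_state=13903)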
    
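    A quick check of the resulting proportions, sketched on a hypothetical df with a label column named y. This version passes absolute row counts (rounded from fractions of the full data) as test_size, so with 200 rows the split comes out as exactly 70/15/15 and each part keeps the 80/20 class ratio.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical data: 200 rows, label column "y" with 160 zeros and 40 ones
df = pd.DataFrame({"x": range(200), "y": [0] * 160 + [1] * 40})
tst_size, vld_size = 0.15, 0.15

# An integer test_size means an absolute number of rows (30 each here)
n_valid, n_test = round(vld_size * len(df)), round(tst_size * len(df))

X_rest, X_valid, y_rest, y_valid = train_test_split(
    df.drop(columns="y"), df["y"], test_size=n_valid, stratify=df["y"], random_state=13903)
X_train, X_test, y_train, y_test = train_test_split(
    X_rest, y_rest, test_size=n_test, stratify=y_rest, random_state=13903)

print(len(X_train), len(X_valid), len(X_test))  # 140 30 30
```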