Stratified Train/Validation/Test-split in scikit-learn

深忆病人 2021-02-15 14:51

There is already a description here of how to do stratified train/test split in scikit via train_test_split (Stratified Train/Test-split in scikit-learn) and a description of ho

2 answers
  •  北恋 (original poster)
    2021-02-15 15:10

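    The solution is to apply StratifiedShuffleSplit twice: first to hold out 40% of the data, then to split that held-out part in half. The result is a 60/20/20 train/test/validation split, stratified on df.target at each step:

    from sklearn.model_selection import StratifiedShuffleSplit

    # First split: 60% train, 40% held out for test + validation.
    split = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
    train_index, test_valid_index = next(split.split(df, df.target))
    train_set = df.iloc[train_index]
    test_valid_set = df.iloc[test_valid_index]

    # Second split: divide the held-out 40% evenly into test and validation.
    # split() yields positional indices, so .iloc is the right accessor.
    split2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    test_index, valid_index = next(split2.split(test_valid_set, test_valid_set.target))
    test_set = test_valid_set.iloc[test_index]
    valid_set = test_valid_set.iloc[valid_index]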
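
For comparison, the same two-step idea can probably be written with train_test_split and its stratify argument, which uses stratified shuffling under the hood. This is only a sketch: the toy DataFrame, its column names, and the 80/20 class balance are made up for illustration and are not from the question.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 100 rows with an imbalanced binary target (80/20).
df = pd.DataFrame({
    "feature": np.arange(100),
    "target": [0] * 80 + [1] * 20,
})

# First split: 60% train, 40% held out, stratified on the target.
train_set, holdout_set = train_test_split(
    df, test_size=0.4, stratify=df["target"], random_state=42
)

# Second split: halve the held-out part into validation and test,
# stratifying again on the target of the held-out rows.
valid_set, test_set = train_test_split(
    holdout_set, test_size=0.5, stratify=holdout_set["target"], random_state=42
)

# Each subset should keep roughly the original 20% share of class 1.
for name, part in [("train", train_set), ("valid", valid_set), ("test", test_set)]:
    print(name, len(part), part["target"].mean())
```

Whichever variant is used, checking the class proportions of each subset afterwards, as the final loop does, is a cheap way to confirm that the stratification took effect.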
    
