Parameter “stratify” of function “train_test_split” (scikit-learn)

我寻月下人不归 2020-12-22 21:37

I am trying to use train_test_split from the scikit-learn package, but I am having trouble with the stratify parameter. Here is the code:



        
5 Answers
  • 2020-12-22 22:04

    For my future self who comes here via Google:

    train_test_split is now in model_selection, hence:

    from sklearn.model_selection import train_test_split
    
    # given:
    # features: xs
    # ground truth: ys
    
    x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                        test_size=0.33,
                                                        random_state=0,
                                                        stratify=ys)
    

    is the way to use it. Setting the random_state is desirable for reproducibility.
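
    To illustrate the reproducibility point, here is a minimal sketch. The arrays xs and ys are hypothetical toy data made up for this example, not anything from the question. It calls train_test_split twice with the same random_state and checks that both calls return identical subsets.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 20 samples, 2 features, two balanced classes.
xs = np.arange(40).reshape(20, 2)
ys = np.array([0, 1] * 10)

# Two independent calls with the same seed.
first = train_test_split(xs, ys, test_size=0.25, random_state=0, stratify=ys)
second = train_test_split(xs, ys, test_size=0.25, random_state=0, stratify=ys)

# Compare x_train, x_test, y_train, y_test pairwise between the two calls.
same = all(np.array_equal(a, b) for a, b in zip(first, second))
print(same)  # True: a fixed random_state gives the same split every time
```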

  • 2020-12-22 22:09

    The stratify parameter splits the data so that the proportion of each value in the resulting subsets matches the proportion of values in the array passed to stratify.

    For example, if variable y is a binary categorical variable with values 0 and 1 and there are 25% of zeros and 75% of ones, stratify=y will make sure that your random split has 25% of 0's and 75% of 1's.
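
    The 25%/75% case above can be checked with a short sketch. The data here is a hypothetical toy set (100 labels, placeholder features), and test_size=0.2 is an arbitrary choice for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy labels: 25 zeros and 75 ones (25% / 75%).
y = np.array([0] * 25 + [1] * 75)
X = np.arange(100).reshape(-1, 1)  # placeholder features, one column

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# Count how many samples of each class ended up in each subset.
train_counts = np.bincount(y_train)
test_counts = np.bincount(y_test)
print(train_counts, test_counts)

# Fraction of ones in each subset; both should match the original 75%.
print(y_train.mean(), y_test.mean())
```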

  • 2020-12-22 22:11

    In this context, stratification means that train_test_split returns training and test subsets with the same proportions of class labels as the input dataset.

  • 2020-12-22 22:13

    Try running this code, it "just works":

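    from sklearn import datasets
    # sklearn.cross_validation was removed in 0.20; use model_selection instead
    from sklearn.model_selection import train_test_split
    
    iris = datasets.load_iris()
    
    X = iris.data[:, :2]  # keep only the first two features
    y = iris.target
    
    x_train, x_test, y_train, y_test = train_test_split(X, y, train_size=.8, stratify=y)
    
    # 30 test labels, ten per class; the order varies between runs because random_state is not set
    y_test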
    
    array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
           1, 2, 1, 1, 0, 2, 1])
    
  • 2020-12-22 22:21

    Scikit-Learn is just telling you it doesn't recognise the argument "stratify", not that you're using it incorrectly. This is because the parameter was added in version 0.17 as indicated in the documentation you quoted.

    So you just need to update Scikit-Learn.
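
    A quick way to see whether the installed version is recent enough is sketched below. The 0.17 threshold comes from the answer above; the regular expression is only an assumption about the usual "major.minor" start of the version string, and the pip command in the comment is one common way to upgrade, not the only one.

```python
import re

import sklearn

# Extract "major.minor" from the version string, e.g. "1.4.2" -> (1, 4).
match = re.match(r"(\d+)\.(\d+)", sklearn.__version__)
major, minor = int(match.group(1)), int(match.group(2))

# stratify is available in train_test_split from version 0.17 onwards.
supports_stratify = (major, minor) >= (0, 17)
print(sklearn.__version__, supports_stratify)

# If this prints False, upgrading should help, for example:
#   pip install -U scikit-learn
```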
