Parameter “stratify” of function “train_test_split” (scikit-learn)

我寻月下人不归 2020-12-22 21:37

I am trying to use train_test_split from the scikit-learn package, but I am having trouble with the stratify parameter. Here is the code:



        
5 answers
  •  礼貌的吻别
    2020-12-22 22:13

    Try running this code; it "just works":

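    from sklearn import datasets
    # sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split now lives in model_selection
    from sklearn.model_selection import train_test_split
    
    iris = datasets.load_iris()
    
    X = iris.data[:, :2]  # keep only the first two features
    y = iris.target
    
    # stratify=y preserves the class proportions of y in both the train and test sets
    x_train, x_test, y_train, y_test = train_test_split(X, y, train_size=.8, stratify=y)
    
    y_test
    
    # example output (the order differs between runs because no random_state is set)
    array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
           1, 2, 1, 1, 0, 2, 1])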
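    A quick way to confirm that stratify did its job is to count the classes in the test set, with and without stratification. This is an illustrative sketch, not part of the original answer; it assumes scikit-learn 0.18 or later (for sklearn.model_selection and load_iris(return_X_y=True)), and random_state=0 is added only to make the splits repeatable.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)  # 150 samples, 50 per class

# 80/20 split with stratification; random_state fixes the shuffle (illustrative choice)
_, _, _, y_test_strat = train_test_split(X, y, train_size=0.8, stratify=y, random_state=0)

# Same split size without stratification, for comparison
_, _, _, y_test_plain = train_test_split(X, y, train_size=0.8, random_state=0)

# Count how many test samples belong to each class 0, 1, 2
print("stratified:  ", np.bincount(y_test_strat, minlength=3))
print("unstratified:", np.bincount(y_test_plain, minlength=3))
```

    Because iris has 50 samples per class, the stratified split places exactly 10 of each class in the 30-sample test set, matching the 10/10/10 counts in the output above. The unstratified counts depend on the shuffle and need not be balanced.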
    
