Evaluating Logistic regression with cross validation

鱼传尺愫 2020-12-23 15:50

I would like to use cross validation to test/train my dataset and evaluate the performance of the logistic regression model on the entire dataset and not only on the test set

1 answer
  • 2020-12-23 16:08

    You got it almost right. cross_val_predict gives you predictions for the entire dataset. (When this was written it lived in sklearn.cross_validation; since scikit-learn 0.18 it is in sklearn.model_selection, and the cross_validation module was removed in 0.20.) You just need to remove the earlier logreg.fit call. Specifically, it divides your dataset into n folds; in each iteration it holds one fold out as the test set and trains the model on the remaining n-1 folds. In the end, every sample has a prediction made by a model that was not trained on it.
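    The fold-by-fold procedure described above can be sketched by hand. This is only an illustration, not scikit-learn's actual implementation: it assumes the current sklearn.model_selection API, uses the iris data, and raises max_iter (an addition not in the original answer) purely to avoid convergence warnings. For a classifier and an integer cv, cross_val_predict splits with stratified folds, so the loop uses StratifiedKFold.

```python
import numpy as np
from sklearn import datasets
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict

X, y = datasets.load_iris(return_X_y=True)

# max_iter is raised only to avoid convergence warnings (assumption, not in the answer).
model = LogisticRegression(max_iter=1000)

# Ten stratified folds, no shuffling, matching cv=10 for a classifier.
folds = StratifiedKFold(n_splits=10)

manual = np.empty_like(y)
for train_idx, test_idx in folds.split(X, y):
    fold_model = clone(model)                    # fresh, unfitted copy for this fold
    fold_model.fit(X[train_idx], y[train_idx])   # train on the other 9 folds
    manual[test_idx] = fold_model.predict(X[test_idx])  # predict the held-out fold

# The library call that the manual loop is meant to mirror.
automatic = cross_val_predict(model, X, y, cv=10)

print("same predictions:", np.array_equal(manual, automatic))
print("samples predicted:", manual.shape[0])
```

    Every index appears in exactly one test fold, which is why the result covers all samples.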

    Let's illustrate this with iris, one of the datasets built into sklearn. It contains 150 samples with 4 features each. iris['data'] is X and iris['target'] is y.

    In [15]: iris['data'].shape
    Out[15]: (150, 4)
    

    To get predictions on the entire set with cross validation you can do the following:

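    # sklearn.cross_validation was removed in scikit-learn 0.20;
    # cross_val_predict now lives in sklearn.model_selection.
    from sklearn.linear_model import LogisticRegression
    from sklearn import metrics, datasets
    from sklearn.model_selection import cross_val_predict
    iris = datasets.load_iris()
    # Each sample is predicted by a model that never saw it during training.
    predicted = cross_val_predict(LogisticRegression(), iris['data'], iris['target'], cv=10)
    # Exact scores depend on the scikit-learn version and solver defaults.
    print(metrics.accuracy_score(iris['target'], predicted))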
    
    Out [1] : 0.9537
    
    print(metrics.classification_report(iris['target'], predicted))
    
    Out [2] :
                         precision    recall  f1-score   support
    
                    0       1.00      1.00      1.00        50
                    1       0.96      0.90      0.93        50
                    2       0.91      0.96      0.93        50
    
          avg / total       0.95      0.95      0.95       150
    

    So, back to your code. All you need is this:

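    from sklearn import metrics
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_predict  # replaces sklearn.cross_validation
    logreg = LogisticRegression()
    # No logreg.fit needed: cross_val_predict fits a fresh copy on each training split.
    predicted = cross_val_predict(logreg, X, y, cv=10)
    print(metrics.accuracy_score(y, predicted))
    print(metrics.classification_report(y, predicted))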
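    To see why the separate logreg.fit call is unnecessary, here is a small check, assuming the current sklearn.model_selection API and using iris as a stand-in for your X and y. cross_val_predict works on clones of the estimator, so the object you pass in stays unfitted. If you also want a final model for new data, you would fit it on the full dataset as a separate step (max_iter is raised here only to avoid convergence warnings).

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

# iris stands in for the asker's X and y (assumption for illustration).
X, y = datasets.load_iris(return_X_y=True)

logreg = LogisticRegression(max_iter=1000)
predicted = cross_val_predict(logreg, X, y, cv=10)

# A fitted LogisticRegression exposes coef_; the original object has none yet,
# because cross-validation fitted clones rather than logreg itself.
fitted_before = hasattr(logreg, "coef_")
print("fitted after cross_val_predict:", fitted_before)

# Optional separate step: train a final model on all the data.
logreg.fit(X, y)
fitted_after = hasattr(logreg, "coef_")
print("fitted after explicit fit:", fitted_after)
```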
    

    For plotting ROC in multi-class classification, you can follow the scikit-learn tutorial on multi-class ROC curves.
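    As a rough sketch of the numbers behind such a plot (not the tutorial's own code), one option is to request out-of-fold class probabilities with method="predict_proba" and compute a one-vs-rest ROC curve per class. The iris data, cv=10 and max_iter=1000 are assumptions carried over for illustration; plotting fpr against tpr for each class would give the curves.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import cross_val_predict
from sklearn.preprocessing import label_binarize

X, y = datasets.load_iris(return_X_y=True)
classes = np.unique(y)

# Out-of-fold class probabilities, one column per class.
proba = cross_val_predict(
    LogisticRegression(max_iter=1000), X, y, cv=10, method="predict_proba"
)

# One-vs-rest: turn the labels into one 0/1 indicator column per class.
y_bin = label_binarize(y, classes=classes)

roc_auc = {}
for i, label in enumerate(classes):
    # ROC curve of "class i versus everything else".
    fpr, tpr, _ = roc_curve(y_bin[:, i], proba[:, i])
    roc_auc[int(label)] = auc(fpr, tpr)
    print("class", label, "AUC:", round(roc_auc[int(label)], 3))
```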

    In general, sklearn has very good tutorials and documentation. I strongly recommend reading their tutorial on cross_validation.
