How to display the confusion matrix and report (recall, precision, F-measure) for each cross-validation fold

北荒 2021-01-23 13:22

I am trying to perform 10-fold cross-validation in Python. I know how to compute the confusion matrix and the classification report for a single train/test split (for example, 80% training and 20% testing), but I would like to display them for each cross-validation fold.

1 answer
  • 2021-01-23 13:49

    Here is a self-contained example using the breast cancer dataset, with 3-fold CV for simplicity:

    from sklearn.datasets import load_breast_cancer
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.metrics import confusion_matrix, classification_report
    from sklearn.model_selection import KFold
    
    X, y = load_breast_cancer(return_X_y=True)
    n_splits = 3
    # shuffle=True without random_state gives different folds on every run
    kf = KFold(n_splits=n_splits, shuffle=True)
    model = DecisionTreeClassifier()
    
    for train_index, val_index in kf.split(X):
        # fit() retrains the tree from scratch on this fold's training rows
        model.fit(X[train_index], y[train_index])
        pred = model.predict(X[val_index])
        # rows of the matrix are true labels, columns are predicted labels
        print(confusion_matrix(y[val_index], pred))
        print(classification_report(y[val_index], pred))
    

    The result is three confusion matrices and classification reports, one per CV fold (the exact numbers vary between runs because the shuffle is not seeded):

    [[ 63   9]
     [ 10 108]]
                  precision    recall  f1-score   support
    
               0       0.86      0.88      0.87        72
               1       0.92      0.92      0.92       118
    
       micro avg       0.90      0.90      0.90       190
       macro avg       0.89      0.90      0.89       190
    weighted avg       0.90      0.90      0.90       190
    
    [[ 66   8]
     [  6 110]]
                  precision    recall  f1-score   support
    
               0       0.92      0.89      0.90        74
               1       0.93      0.95      0.94       116
    
       micro avg       0.93      0.93      0.93       190
       macro avg       0.92      0.92      0.92       190
    weighted avg       0.93      0.93      0.93       190
    
    [[ 59   7]
     [  8 115]]
                  precision    recall  f1-score   support
    
               0       0.88      0.89      0.89        66
               1       0.94      0.93      0.94       123
    
       micro avg       0.92      0.92      0.92       189
       macro avg       0.91      0.91      0.91       189
    weighted avg       0.92      0.92      0.92       189
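
The question asks about 10 folds, so the same loop can be extended as in the sketch below. It is only one possible variant, not the answer's own code: the use of `StratifiedKFold`, the fixed `random_state=0`, and the names `fold_matrices`, `fold_reports`, `total_cm` and `mean_macro_f1` are assumptions introduced here for illustration. It keeps the per-fold output and additionally sums the per-fold confusion matrices into one pooled matrix.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Assumption: stratified folds and a fixed seed (neither is in the answer above),
# so every fold keeps the class ratio and the run can be repeated.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

labels = np.unique(y)   # fixed label order, so every matrix has the same shape
fold_matrices = []      # one confusion matrix per fold
fold_reports = []       # one metrics dict per fold

for fold, (train_index, val_index) in enumerate(skf.split(X, y), start=1):
    model = DecisionTreeClassifier(random_state=0)  # fresh model for each fold
    model.fit(X[train_index], y[train_index])
    pred = model.predict(X[val_index])

    cm = confusion_matrix(y[val_index], pred, labels=labels)
    fold_matrices.append(cm)
    # output_dict=True returns the report as a nested dict instead of text
    fold_reports.append(
        classification_report(y[val_index], pred, labels=labels, output_dict=True)
    )

    print(f"Fold {fold}")
    print(cm)
    print(classification_report(y[val_index], pred, labels=labels))

# Each sample is in exactly one validation fold, so the summed matrix
# counts every sample once.
total_cm = np.sum(fold_matrices, axis=0)
print("Confusion matrix pooled over all folds")
print(total_cm)

# Average of the per-fold macro F1 scores
mean_macro_f1 = np.mean([r["macro avg"]["f1-score"] for r in fold_reports])
print(f"Mean macro F1: {mean_macro_f1:.3f}")
```

Keeping the reports as dicts makes it easy to average any other metric across folds in the same way, for example `r["1"]["recall"]` for the recall of class 1.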
    