Computing scikit-learn multiclass ROC Curve with cross validation (CV)

一向 2021-01-07 03:06

I want to evaluate my classification models with a ROC curve. I'm struggling with computing a multiclass ROC curve for a cross-validated data set. There is no division in t

1 Answer
  • 2021-01-07 03:48

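    You can use label_binarize (from sklearn.preprocessing) to turn the class labels into one indicator column per class, then compute one ROC curve per class from cross-validated probabilities. This gives the desired plot as output.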
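As a quick illustration of what the binarizing step produces, here is a minimal sketch on a made-up label vector (`y_demo` is a hypothetical name, not part of the answer). Each row gets a 1 in the column of its class, which is the one-vs-rest target that `roc_curve` expects. Note that with only two classes `label_binarize` returns a single column rather than two.

```python
import numpy as np
from sklearn.preprocessing import label_binarize

# Hypothetical toy labels, for illustration only
y_demo = np.array([0, 2, 1, 2])

# One indicator column per class listed in `classes`
y_demo_bin = label_binarize(y_demo, classes=[0, 1, 2])

print(y_demo_bin.shape)  # (4, 3)
print(y_demo_bin)
```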

    Example using Iris data:

    import matplotlib.pyplot as plt
    from sklearn import datasets
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_curve, auc
    from sklearn.model_selection import cross_val_predict
    from sklearn.multiclass import OneVsRestClassifier  # only needed for the alternative below
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler, label_binarize
    
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    
    # Binarize the output
    y_bin = label_binarize(y, classes=[0, 1, 2])
    n_classes = y_bin.shape[1]
    
    pipe = Pipeline([('scaler', StandardScaler()), ('clf', LogisticRegression())])
    # or
    # clf = OneVsRestClassifier(LogisticRegression())
    # pipe = Pipeline([('scaler', StandardScaler()), ('clf', clf)])

    # Out-of-fold class probabilities: each sample is scored by a model
    # that did not see it during training
    y_score = cross_val_predict(pipe, X, y, cv=10, method='predict_proba')
    
    # One-vs-rest ROC curve and AUC for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    lw = 2  # line width used for all curves
    colors = ['blue', 'red', 'green']  # one color per class
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)  # chance level
    plt.xlim([-0.05, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic for multi-class data')
    plt.legend(loc="lower right")
    plt.show()
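If a single summary number is wanted in addition to the per-class curves, the same out-of-fold probabilities can be pooled. The following is only a sketch under the same assumptions as the example (Iris data, 10-fold CV, a scaled logistic regression); the names `auc_micro` and `auc_macro` are illustrative. The micro-average treats every (sample, class) pair as one binary decision, while the macro-average is the unweighted mean of the per-class one-vs-rest AUCs, here computed with `roc_auc_score(..., multi_class="ovr")`.

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, label_binarize

X, y = datasets.load_iris(return_X_y=True)
y_bin = label_binarize(y, classes=[0, 1, 2])

pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])

# Out-of-fold probabilities, one row per sample and one column per class
y_score = cross_val_predict(pipe, X, y, cv=10, method="predict_proba")

# Micro-average: flatten labels and scores into one pooled binary problem
fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_score.ravel())
auc_micro = auc(fpr_micro, tpr_micro)

# Macro-average: unweighted mean of the per-class one-vs-rest AUCs
auc_macro = roc_auc_score(y, y_score, multi_class="ovr", average="macro")

print(f"micro-average AUC: {auc_micro:.3f}")
print(f"macro-average AUC: {auc_macro:.3f}")
```

Because the probabilities from all folds are pooled before the curve is computed, this is one ROC curve over the whole data set, not an average of per-fold curves.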
    
