Computing scikit-learn multiclass ROC Curve with cross validation (CV)

Submitted by 核能气质少年 on 2021-02-04 16:31:10

Question


I want to evaluate my classification models with a ROC curve, but I'm struggling to compute a multiclass ROC curve for a cross-validated data set. Because of the cross-validation, there is no split into a training set and a test set. Below is the code I have already tried.
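Even without a fixed train/test split, cross-validation still gives every sample a prediction from a model that never saw it during training, and cross_val_predict pools these out-of-fold predictions. A minimal sketch, assuming scikit-learn's bundled Iris data and a plain LogisticRegression as stand-ins for the question's data and model, that rebuilds the same out-of-fold probabilities with an explicit StratifiedKFold loop:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Stand-in data and model (assumptions, not the question's own).
X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000)

# Out-of-fold class probabilities: one row per sample, one column per class.
# For a classifier, an integer cv means StratifiedKFold without shuffling.
y_score = cross_val_predict(model, X, y, cv=5, method="predict_proba")

# The same computation by hand: each sample is scored exactly once,
# by a model fitted only on the other folds.
manual = np.zeros((len(y), 3))
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
    fitted = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    manual[test_idx] = fitted.predict_proba(X[test_idx])

print(y_score.shape)  # (150, 3)
print(np.allclose(y_score, manual))
```

These pooled scores can then be fed to roc_curve exactly as if they came from a single held-out test set.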

scaler = StandardScaler(with_mean=False)

enc = LabelEncoder()
y = enc.fit_transform(labels)
vec = DictVectorizer()

feat_sel = SelectKBest(mutual_info_classif, k=200)

n_classes = 3

# Pipeline for computing ROC curves
clf = OneVsRestClassifier(LogisticRegression(solver='newton-cg', multi_class='multinomial'))
clf = clf.label_binarizer_
pipe = Pipeline([('vectorizer', vec),
                 ('scaler', scaler),
                 ('Logreg', clf),
                 ('mutual_info', feat_sel)])

y_pred = model_selection.cross_val_predict(pipe, instances, y, cv=10)


fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot of a ROC curve for a specific class
for i in range(n_classes):
    plt.figure()
    plt.plot(fpr[i], tpr[i], label='ROC curve (area = %0.2f)' % roc_auc[i])
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()

I thought I could binarize my y_pred by using the label_binarizer_ attribute of OneVsRestClassifier, as described in the documentation for sklearn.multiclass.OneVsRestClassifier.

However, I get the following error: AttributeError: 'OneVsRestClassifier' object has no attribute 'label_binarizer_'. I don't understand this error, because the documentation lists it as an attribute of this classifier.
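Attributes whose names end in an underscore, such as label_binarizer_, are only created when the estimator is fitted, so reading one straight after construction raises AttributeError. A minimal sketch, using the Iris data as a stand-in for the question's features:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = load_iris(return_X_y=True)
clf = OneVsRestClassifier(LogisticRegression(max_iter=1000))

# Before fit: fitted attributes (trailing underscore) do not exist yet.
before = hasattr(clf, "label_binarizer_")
print(before)  # False

clf.fit(X, y)

# After fit: the binarizer exists and maps each label to one indicator column per class.
y_bin = clf.label_binarizer_.transform(y)
print(y_bin.shape)  # (150, 3)
```

Note that cross_val_predict fits clones of the estimator it receives, so the clf object passed to it stays unfitted even after cross-validation; binarizing y directly with label_binarize sidesteps the fitted attribute entirely.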

When I add instances = DataFrame(instances) and clf.fit(instances, y), I get the error ValueError: Input contains NaN, infinity or a value too large for dtype('float64'). instances is a list of feature-vector dictionaries. I also tried instances = np.array(instances) instead, but that gives me the error TypeError: float() argument must be a string or a number, not 'dict'.
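A likely cause, assuming the dictionaries do not all share the same keys: DataFrame(instances) fills missing keys with NaN, and np.array(instances) only produces an object array of dicts. DictVectorizer, already the first step of the question's pipeline, is the component meant to turn such a list into a numeric matrix, with missing keys becoming 0. The feature dictionaries below are made up for illustration:

```python
import numpy as np
from sklearn.feature_extraction import DictVectorizer

# Hypothetical feature dictionaries; the second one lacks "len" and adds "has_digit".
instances = [
    {"len": 5, "vowels": 2},
    {"vowels": 1, "has_digit": 1},
]

# sparse=False only to make the result easy to print; missing keys become 0.
vec = DictVectorizer(sparse=False)
X = vec.fit_transform(instances)

print(vec.feature_names_)  # ['has_digit', 'len', 'vowels']
print(X)
```

Keeping the raw dicts as input and letting the pipeline's DictVectorizer do this conversion inside each fold avoids both errors.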

What am I doing wrong?


Answer 1:


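You can binarize the labels with the label_binarize function from sklearn.preprocessing (rather than the label_binarizer_ attribute) and get per-class scores from cross_val_predict with method='predict_proba'; together these give the desired plot.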

Example using Iris data:

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import cross_val_predict
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, label_binarize

iris = datasets.load_iris()
X = iris.data
y = iris.target

# Binarize the output
y_bin = label_binarize(y, classes=[0, 1, 2])
n_classes = y_bin.shape[1]

pipe = Pipeline([('scaler', StandardScaler()), ('clf', LogisticRegression())])
# or
# clf = OneVsRestClassifier(LogisticRegression())
# pipe = Pipeline([('scaler', StandardScaler()), ('clf', clf)])
# Out-of-fold class probabilities: one row per sample, one column per class
y_score = cross_val_predict(pipe, X, y, cv=10, method='predict_proba')

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
lw = 2  # line width for all curves
colors = ['blue', 'red', 'green']  # one color per class
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([-0.05, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic for multi-class data')
plt.legend(loc="lower right")
plt.show()
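The same approach carries over to the question's setup. A sketch under assumptions: the feature dictionaries and labels are synthetic stand-ins generated below, k is lowered to fit the small feature set, and the pipeline is reordered so that feature selection comes before the classifier (intermediate Pipeline steps must be transformers, so the classifier has to be the last step). The per-class AUCs are also cross-checked against roc_auc_score with multi_class='ovr', whose macro average is the mean of the one-vs-rest AUCs.

```python
import numpy as np
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler, label_binarize

# Synthetic stand-in data (assumption): 300 feature dicts with 30 numeric keys;
# every third feature is shifted upward for samples of a given class.
rng = np.random.default_rng(0)
labels = rng.choice(["neg", "neu", "pos"], size=300)
enc = LabelEncoder()
y = enc.fit_transform(labels)  # 'neg', 'neu', 'pos' -> 0, 1, 2
instances = [
    {f"f{j}": float(rng.normal(loc=1.5 if j % 3 == cls else 0.0)) for j in range(30)}
    for cls in y
]

pipe = Pipeline([
    # Dense output so mutual_info_classif treats the features as continuous.
    ("vectorizer", DictVectorizer(sparse=False)),
    ("scaler", StandardScaler(with_mean=False)),
    ("mutual_info", SelectKBest(mutual_info_classif, k=10)),
    ("logreg", LogisticRegression(max_iter=1000)),  # classifier goes last
])

# Out-of-fold probabilities; columns follow the class order 0, 1, 2.
y_score = cross_val_predict(pipe, instances, y, cv=10, method="predict_proba")
y_bin = label_binarize(y, classes=[0, 1, 2])

# One-vs-rest ROC curve and AUC for each class.
roc_auc = {}
for i in range(3):
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr, tpr)

# Macro-averaged one-vs-rest AUC computed by scikit-learn directly.
macro = roc_auc_score(y, y_score, multi_class="ovr", average="macro")
print({enc.classes_[i]: round(v, 3) for i, v in roc_auc.items()})
print(round(macro, 3), round(np.mean(list(roc_auc.values())), 3))
```

The fpr and tpr arrays from this loop can be plotted exactly as in the Iris example above.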



Source: https://stackoverflow.com/questions/45641409/computing-scikit-learn-multiclass-roc-curve-with-cross-validation-cv
