Scikit-learn GridSearch giving “ValueError: multiclass format is not supported” error

-上瘾入骨i 2021-02-07 12:06

I'm trying to use GridSearch for parameter estimation of LinearSVC() as follows:

clf_SVM = LinearSVC()
params = {
          'C': [0.5, 1.0, 1.5],
          …


        
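The question's code is cut off, so what follows is only a hypothetical minimal reproduction, not the asker's actual script. Synthetic four-class data from make_classification stands in for the real inputs, and error_score='raise' is set so the scoring failure surfaces as an exception (some scikit-learn releases would otherwise record it as a NaN score with a warning). Older releases word the error as "multiclass format is not supported", as in the question. Newer ones may phrase it differently.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

# Hypothetical stand-in for the question's data: four classes labelled 0..3
X, y = make_classification(n_samples=200, n_features=20, n_informative=6,
                           n_classes=4, random_state=0)

# 'roc_auc' cannot score 1-D multiclass labels, so fitting fails
gs = GridSearchCV(LinearSVC(), {'C': [0.5, 1.0, 1.5]}, cv=5,
                  scoring='roc_auc', error_score='raise')
try:
    gs.fit(X, y)
    failed = False
except ValueError as exc:
    # the exact message depends on the scikit-learn version
    print("ValueError:", exc)
    failed = True
```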
3 answers
  • 2021-02-07 12:32

    As has been pointed out, you must first binarize y:

    from sklearn.preprocessing import label_binarize
    y = label_binarize(y, classes=[0, 1, 2, 3])
    

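    and then use a multiclass meta-estimator that accepts the resulting label-indicator matrix, such as OneVsRestClassifier (OneVsOneClassifier expects 1-D multiclass labels, so it does not work with a binarized y). For example: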

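    from sklearn.model_selection import GridSearchCV
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.svm import LinearSVC

    # fits one LinearSVC per column of the binarized y
    clf_SVM = OneVsRestClassifier(LinearSVC())
    # parameters of the wrapped estimator take the "estimator__" prefix
    params = {
          'estimator__C': [0.5, 1.0, 1.5],
          'estimator__tol': [1e-3, 1e-4, 1e-5],
          }
    gs = GridSearchCV(clf_SVM, params, cv=5, scoring='roc_auc')
    gs.fit(corpus1, y)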
    
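A self-contained sketch of the approach above. It assumes synthetic data from make_classification in place of corpus1 and the original labels. The sample counts and random_state are illustrative assumptions, not taken from the question. Because y is binarized, OneVsRestClassifier treats the task as multilabel, and the 'roc_auc' scorer averages the per-class AUCs (macro average).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize
from sklearn.svm import LinearSVC

# Hypothetical stand-in for corpus1 / y: 200 samples, four classes 0..3
X, y = make_classification(n_samples=200, n_features=20, n_informative=6,
                           n_classes=4, random_state=0)

# Turn the 1-D multiclass labels into a (200, 4) label-indicator matrix
Y = label_binarize(y, classes=[0, 1, 2, 3])

clf_SVM = OneVsRestClassifier(LinearSVC())
params = {
    'estimator__C': [0.5, 1.0, 1.5],
    'estimator__tol': [1e-3, 1e-4, 1e-5],
}
# With an indicator matrix, 'roc_auc' is computed per column and averaged
gs = GridSearchCV(clf_SVM, params, cv=5, scoring='roc_auc')
gs.fit(X, Y)
print(gs.best_params_, round(gs.best_score_, 3))
```

With a 2-D target, an integer cv makes GridSearchCV use plain KFold rather than StratifiedKFold, so the folds are not stratified by class.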
  • 2021-02-07 12:38

    From the roc_auc_score documentation:

    http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score

    "Note: this implementation is restricted to the binary classification task or multilabel classification task in label indicator format."

    Try:

    from sklearn import preprocessing
    y = preprocessing.label_binarize(y, classes=[0, 1, 2, 3])
    

    before you train. This performs a "one-hot" encoding of your y.

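A small illustration of the quoted restriction and of what label_binarize produces. It uses a hypothetical toy label vector and, as scores, the indicator matrix itself, which is a perfect ranking for every class. The quoted note describes the documentation of that time. Since scikit-learn 0.22, roc_auc_score also has a multi_class parameter, but with its default value raw 1-D multiclass labels are still rejected with a ValueError.

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# Hypothetical toy labels: every class appears twice
y = np.array([0, 1, 2, 3, 0, 1, 2, 3])

# One-hot encoding: one column per class, a single 1 per row
Y = label_binarize(y, classes=[0, 1, 2, 3])
print(Y[:4])
# [[1 0 0 0]
#  [0 1 0 0]
#  [0 0 1 0]
#  [0 0 0 1]]

# Hypothetical per-class scores: the indicator matrix itself
scores = Y.astype(float)
print(roc_auc_score(Y, scores))  # 1.0 (macro average over the four columns)

# Raw 1-D multiclass labels are rejected with default arguments
try:
    roc_auc_score(y, scores)
    rejected = False
except ValueError:
    rejected = True
```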
  • 2021-02-07 12:48

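    Remove scoring='roc_auc' and it will work, because the 'roc_auc' scorer does not support multiclass (categorical) targets. Without a scoring argument, GridSearchCV ranks the candidates with the estimator's own score method, which for classifiers is mean accuracy.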

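A sketch of this simpler fix, again on hypothetical synthetic four-class data rather than the question's corpus. LinearSVC handles the multiclass labels directly (one-vs-rest by default). Because no scoring is given, candidates are compared by mean accuracy on the held-out folds.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

# Hypothetical stand-in for the question's data: four classes labelled 0..3
X, y = make_classification(n_samples=200, n_features=20, n_informative=6,
                           n_classes=4, random_state=0)

# No scoring argument: each candidate is ranked by LinearSVC.score,
# i.e. mean accuracy on the held-out fold
gs = GridSearchCV(LinearSVC(), {'C': [0.5, 1.0, 1.5]}, cv=5)
gs.fit(X, y)
print(gs.best_params_, round(gs.best_score_, 3))
```

The trade-off is that the selected parameters now maximize accuracy rather than ROC AUC.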