Classification metrics can't handle a mix of binary and continuous targets

小鲜肉 2021-02-09 11:51

I am trying to train and test several scikit-learn models and print the accuracy of each. Only some of these models work; the others fail with the following error:

ValueError: Classification metrics can't handle a mix of binary and continuous targets


        
2 Answers
  • 2021-02-09 12:34

    All your commented-out models are not classifiers but regression models, for which accuracy is meaningless.

    You get the error because these regression models do not produce binary outcomes but continuous (float) numbers, as all regression models do. So when scikit-learn tries to compute the accuracy by comparing a binary number (the true label) with a float (the predicted value), it unsurprisingly raises an error. The error message itself clearly hints at this cause:

    Classification metrics can't handle a mix of binary and continuous targets
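    A minimal reproduction, assuming scikit-learn and NumPy are installed. The toy arrays are made up for illustration, with y_pred standing in for the float output of a regressor: accuracy_score rejects them with exactly this ValueError, while a regression metric such as mean_squared_error accepts the same predictions.

```python
import numpy as np
from sklearn.metrics import accuracy_score, mean_squared_error

# Hypothetical toy data: binary ground truth vs. float predictions from a regressor
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0.1, 0.9, 0.6, 0.4, 0.7])

# accuracy_score needs discrete labels on both sides, so this raises ValueError
try:
    accuracy_score(y_true, y_pred)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)

# A regression metric accepts continuous predictions as they are
mse = mean_squared_error(y_true, y_pred)
print(mse)
```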
    

    Note also that the accepted (and highly upvoted!) answer to the question suggested in the first comment as a possible duplicate of yours is wrong; there, as here, the root cause is using accuracy with a LinearRegression model, which, as already said, is meaningless.
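    For a loop over several models like the one in the question, one possible approach is to choose the metric per model type: accuracy for classifiers, R^2 (r2_score) for regressors. This is a sketch, not the asker's code (which is not shown here); the synthetic dataset and the three models are hypothetical stand-ins, and sklearn.base.is_classifier does the dispatching.

```python
from sklearn.base import is_classifier
from sklearn.datasets import make_classification
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical binary-classification data standing in for the asker's dataset
X, y = make_classification(n_samples=200, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Two classifiers and one regressor, mixed in one list as in the question
models = [
    LogisticRegression(max_iter=1000),
    DecisionTreeClassifier(random_state=0),
    LinearRegression(),
]

results = {}
for model in models:
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    name = type(model).__name__
    if is_classifier(model):
        # predict() returns class labels, so accuracy is well defined
        results[name] = ("accuracy", accuracy_score(y_test, y_pred))
    else:
        # predict() returns floats, so use a regression metric instead
        results[name] = ("r2", r2_score(y_test, y_pred))

for name, (metric, value) in results.items():
    print(f"{name}: {metric} = {value:.3f}")
```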

  • 2021-02-09 12:37

    I used a few models for stacking with vecstack, set needs_proba=True, and then got this error. I solved it by changing the metric used inside the stacking: stacking uses class predictions by default, so if you want probabilities you have to change the metric as well. I defined a new function to use as the metric:

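    import numpy as np
    from sklearn.metrics import precision_recall_curve

    def get_classification_metric(testy, probs):
        # precision and recall at every threshold, scored on the positive-class probability
        precision, recall, thresholds = precision_recall_curve(testy, probs[:, 1])
        # convert to F1 score; where precision + recall == 0, use 0 instead of dividing by zero
        denom = precision + recall
        fscore = np.divide(2 * precision * recall, denom,
                           out=np.zeros_like(denom), where=denom > 0)
        # locate the index of the largest F1 score
        ix = np.argmax(fscore)
        return fscore[ix]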
    

    This function returns the highest F1 score, i.e. the F1 score at the optimal decision threshold. Then you only need to pass metric=get_classification_metric to the stacking function.
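    To check the metric outside of stacking, the sketch below applies the same function to the predict_proba output of a LogisticRegression fitted on synthetic data (both are assumptions for illustration; vecstack itself is not used here). It also computes sklearn's log_loss, an existing metric that accepts probabilities directly and may serve as an alternative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss, precision_recall_curve
from sklearn.model_selection import train_test_split


def get_classification_metric(testy, probs):
    # Best F1 over all thresholds, from an (n_samples, 2) probability array
    precision, recall, _ = precision_recall_curve(testy, probs[:, 1])
    denom = precision + recall
    fscore = np.divide(2 * precision * recall, denom,
                       out=np.zeros_like(denom), where=denom > 0)
    return fscore[np.argmax(fscore)]


# Hypothetical data and model, only to produce a probability array
X, y = make_classification(n_samples=300, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
probs = LogisticRegression(max_iter=1000).fit(X_train, y_train).predict_proba(X_test)

best_f1 = get_classification_metric(y_test, probs)  # threshold-optimised F1
ll = log_loss(y_test, probs)  # probability-based metric built into sklearn
print(round(float(best_f1), 3), round(float(ll), 3))
```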
