Plot Confusion Matrix with scikit-learn without a Classifier

花落未央 2021-02-07 15:55

I have a confusion matrix created with sklearn.metrics.confusion_matrix.

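Now I would like to plot it with sklearn.metrics.plot_confusion_matrix. However, that function expects a fitted classifier (plus X and y), not a precomputed matrix. How can I plot the matrix I already have?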

1 Answer
  •  醉话见心
    2021-02-07 16:52

    The fact that you can import plot_confusion_matrix directly suggests that you have scikit-learn 0.22 or later installed, because that is the release that introduced it. So you can look at the source code of plot_confusion_matrix() to see how it uses the estimator.

    In that source, the estimator is used for only two things:

    1. computing the confusion matrix with confusion_matrix, from the estimator's predictions on X
    2. getting the labels (estimator.classes_, the unique values of y, which correspond to rows/columns 0, 1, 2, ... of the confusion matrix)
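Without an estimator, both steps can be reproduced directly from the label arrays. A minimal sketch, assuming you still have y_true and y_pred (the values below are placeholders) and using np.unique as a stand-in for estimator.classes_:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Placeholder data: replace with your own true and predicted labels.
y_true = ["cat", "dog", "dog", "cat", "bird"]
y_pred = ["cat", "dog", "cat", "cat", "bird"]

# Stand-in for step 2: sorted unique labels. With a fitted classifier this
# would be estimator.classes_; here we assume there is no estimator.
labels = np.unique(np.concatenate([y_true, y_pred]))

# Step 1: row i / column j of cm correspond to labels[i] / labels[j]
# (rows = true label, columns = predicted label).
cm = confusion_matrix(y_true, y_pred, labels=labels)

print(labels)  # ['bird' 'cat' 'dog']
print(cm)
```

Passing labels explicitly fixes the row/column order, so the same array can be handed to ConfusionMatrixDisplay as display_labels.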

    So if you already have both of those, you only need the plotting part:

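    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import ConfusionMatrixDisplay

    # Your precomputed matrix (e.g. from confusion_matrix) and the class names
    # for its rows/columns, in the same order. Placeholder values shown here.
    cm = np.array([[5, 1],
                   [2, 7]])
    display_labels = ["cat", "dog"]

    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=display_labels)

    # NOTE: these are the default values plot_confusion_matrix would pass;
    # change them to suit your plot.
    disp = disp.plot(include_values=True,
                     cmap="viridis", ax=None, xticks_rotation="horizontal")

    plt.show()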
    

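    Do look at the NOTE in the comment: the plotting arguments must be filled in, either with plot_confusion_matrix's defaults or with your own values.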

    For versions older than 0.22, which have neither plot_confusion_matrix nor ConfusionMatrixDisplay, you can see how the plot is drawn with plain matplotlib in this example:

    • https://scikit-learn.org/0.21/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
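For those older versions, here is a minimal plain-matplotlib sketch in the spirit of the linked example (not a copy of it). plot_cm is a hypothetical helper name; it assumes a square matrix whose rows are true labels and columns are predicted labels, as confusion_matrix returns.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; delete this line to see a window
import matplotlib.pyplot as plt
import numpy as np


def plot_cm(cm, display_labels, cmap="viridis", fmt="d", ax=None):
    """Draw a precomputed confusion matrix using only matplotlib.

    Rows are true labels, columns are predicted labels.
    Use fmt=".2f" for a normalized (float) matrix.
    """
    cm = np.asarray(cm)
    if ax is None:
        _, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)

    n = cm.shape[0]
    ticks = np.arange(n)
    ax.set_xticks(ticks)
    ax.set_xticklabels(display_labels)
    ax.set_yticks(ticks)
    ax.set_yticklabels(display_labels)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    # Annotate each cell, taking the text color from the opposite end of the
    # colormap so it stays readable on both dark and light cells.
    cmap_min, cmap_max = im.cmap(0.0), im.cmap(1.0)
    thresh = (cm.max() + cm.min()) / 2.0
    for i in range(n):
        for j in range(n):
            color = cmap_max if cm[i, j] < thresh else cmap_min
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center", color=color)
    return ax


ax = plot_cm([[1, 0, 0], [0, 2, 0], [0, 1, 1]], ["bird", "cat", "dog"])
plt.show()
```

Passing an existing ax lets you place several matrices side by side in one figure.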
