How to plot confusion matrix with string axis rather than integer in python

前端 未结 4 1835
星月不相逢
星月不相逢 2020-12-02 08:18

I am following a previous thread on how to plot a confusion matrix in Matplotlib. The script is as follows:

from numpy import *
import matplotlib.pyplot as plt         


        
相关标签:
4条回答
  • 2020-12-02 08:59

    If your results are stored in a CSV file, you can use this method directly; otherwise, you might have to make some changes to suit the structure of your results.

    Modifying example from sklearn's website:

    import itertools

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from sklearn.metrics import confusion_matrix
    
    def plot_confusion_matrix(cm, classes,
                              normalize=False,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues):
        """
        Print a confusion matrix and draw it on the current pyplot figure.

        The matrix is rendered with ``plt.imshow``; every cell is annotated
        with its value and both axes are labelled with ``classes`` instead of
        integer indices. The function does not call ``plt.show()``.

        Parameters
        ----------
        cm : 2-D array-like
            Confusion matrix. Rows are drawn on the y axis ('True label'),
            columns on the x axis ('Predicted label').
        classes : sequence of str
            Tick labels, one per row/column, in matrix order.
        normalize : bool
            If True, each row is divided by its row total before plotting.
        title : str
            Title of the plot.
        cmap : matplotlib colormap
            Colormap passed to ``plt.imshow``.
        """
        # Accept nested lists as well as ndarrays.
        cm = np.asarray(cm)

        if normalize:
            row_totals = cm.sum(axis=1, keepdims=True)
            # A class with no samples has a zero row total; dividing by 1
            # leaves that row at zero instead of filling it with NaN.
            row_totals = np.where(row_totals == 0, 1, row_totals)
            cm = cm.astype('float') / row_totals
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')

        print(cm)

        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # 'd' only works for integer counts; any float matrix (normalized or
        # passed in as floats) is shown with two decimals.
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
        # Cells darker than half the maximum get white text for contrast.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
    
    
    # Assuming that your predicted results are in a CSV file. If not, you can
    # still modify the example to suit your requirements.
    df = pd.read_csv("dataframe.csv", index_col=0)

    # Get the unique class text for each numerically represented actual class,
    # ordered by class number.
    unique_class_df = df.drop_duplicates(['actual_class_num','actual_class_text']).sort_values("actual_class_num")

    # Pin the matrix rows/columns to the same class order as the tick labels,
    # so the names stay aligned even if the predictions contain a class that
    # never occurs in the actual column.
    cnf_matrix = confusion_matrix(df["actual_class_num"], df["predicted_class_num"],
                                  labels=unique_class_df["actual_class_num"].tolist())

    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cnf_matrix, classes=unique_class_df["actual_class_text"].tolist(),
                          title='Confusion matrix, without normalization')
    # Display the figure; plot_confusion_matrix only draws on it.
    plt.show()
    

    Output would look something like:

    0 讨论(0)
  • 2020-12-02 09:05

    Here is what you want:

    from string import ascii_uppercase
    from pandas import DataFrame
    import numpy as np
    import seaborn as sn
    from sklearn.metrics import confusion_matrix
    
    # True labels and predictions for the same 15 samples.
    y_test = np.array([1, 2, 3, 4, 5] * 3)
    predic = np.array([1,2,4,3,5, 1,2,4,3,5, 1,2,3,4,4])

    # One display name per distinct true class: 'class A', 'class B', ...
    n_classes = len(np.unique(y_test))
    columns = ['class %s' % letter for letter in ascii_uppercase[:n_classes]]

    # Wrap the counts in a DataFrame whose index and columns carry the
    # string class names.
    confm = confusion_matrix(y_test, predic)
    df_cm = DataFrame(confm, index=columns, columns=columns)

    ax = sn.heatmap(df_cm, annot=True, cmap='Oranges')
    

    Example image output is here:


    If you want a more complete confusion matrix like the MATLAB default, with totals (last row and last column) and percentages in each cell, see the module below.

    I scoured the internet and didn't find a confusion matrix like this one for Python, so I developed one with these improvements and shared it on GitHub.


    REF:

    https://github.com/wcipriano/pretty-print-confusion-matrix

    The output example is here:

    0 讨论(0)
  • 2020-12-02 09:08

    Just use matplotlib.pyplot.xticks and matplotlib.pyplot.yticks.

    E.g.

    import matplotlib.pyplot as plt
    import numpy as np
    
    # Show a random 5x5 grid, then replace the integer ticks with letters.
    grid = np.random.random((5, 5))
    plt.imshow(grid, interpolation='nearest')

    tick_positions = np.arange(5)
    plt.xticks(tick_positions, list('ABCDE'))
    plt.yticks(tick_positions, list('FGHIJ'))

    plt.show()
    

    enter image description here

    0 讨论(0)
  • 2020-12-02 09:13

    Here's what I'm guessing you want: enter image description here

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Raw confusion-matrix counts. Stored as an ndarray so that .shape and
    # vectorized arithmetic are available below.
    conf_arr = np.array([[33,2,0,0,0,0,0,0,0,1,3],
                         [3,31,0,0,0,0,0,0,0,0,0],
                         [0,4,41,0,0,0,0,0,0,0,1],
                         [0,1,0,30,0,6,0,0,0,0,1],
                         [0,0,0,0,38,10,0,0,0,0,0],
                         [0,0,0,3,1,39,0,0,0,0,4],
                         [0,2,2,0,4,1,31,0,0,0,2],
                         [0,1,0,0,0,0,0,36,0,2,0],
                         [0,0,0,0,0,0,1,5,37,5,1],
                         [3,0,0,0,0,0,0,0,0,39,0],
                         [0,0,0,0,0,0,0,0,0,0,38]])

    # Normalize every row to fractions of its row total; these fractions only
    # drive the cell colors. An all-zero row is divided by 1 so it stays zero
    # instead of raising or producing NaN.
    row_totals = conf_arr.sum(axis=1, keepdims=True)
    norm_conf = conf_arr / np.where(row_totals == 0, 1, row_totals)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect(1)
    res = ax.imshow(norm_conf, cmap=plt.cm.jet,
                    interpolation='nearest')

    # shape is (rows, columns): rows run along the y axis, columns along x.
    height, width = conf_arr.shape

    # Write the raw count in the middle of each cell; imshow places
    # element [row, col] at x=col, y=row.
    for row in range(height):
        for col in range(width):
            ax.annotate(str(conf_arr[row][col]), xy=(col, row),
                        horizontalalignment='center',
                        verticalalignment='center')

    cb = fig.colorbar(res)
    # Label the axes with letters instead of integer indices.
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    plt.xticks(range(width), list(alphabet[:width]))
    plt.yticks(range(height), list(alphabet[:height]))
    plt.savefig('confusion_matrix.png', format='png')
    
    0 讨论(0)
提交回复
热议问题