Scikit-learn: how to print labels for a confusion matrix?

孤城傲影 2021-02-13 16:52

So I'm using scikit-learn to classify some data. I have 13 different class values/categories to classify the data into. Now I have been able to use cross validation and print t…

5 Answers
  •  醉酒成梦
    2021-02-13 17:04

    Since a confusion matrix is just a NumPy array, it does not carry any row or column labels. What you can do is convert the matrix into a pandas DataFrame that uses the class labels as its index and columns, and then print that DataFrame.

    import numpy as np
    import pandas as pd
    
    def cm2df(cm, labels):
        """Wrap a square confusion matrix in a DataFrame with labeled rows and columns.

        Rows are the true classes and columns the predicted classes, both in
        the order given by `labels`.
        """
        # The original row-by-row version used DataFrame.append, which was
        # removed in pandas 2.0; the constructor does the same job in one call.
        return pd.DataFrame(cm, index=labels, columns=labels)
    
    cm = np.arange(9).reshape((3, 3))
    df = cm2df(cm, ["a", "b", "c"])
    print(df)
    

    Adapted from the snippet at https://gist.github.com/nickynicolson/202fe765c99af49acb20ea9f77b6255e

    Output:

       a  b  c
    a  0  1  2
    b  3  4  5
    c  6  7  8
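    A minimal end-to-end sketch with scikit-learn itself. It assumes out-of-fold predictions from `cross_val_predict` (the question mentions cross-validation but is cut off, so this setup is an assumption) and uses the iris dataset as a stand-in for the 13-class data. The key step is passing `labels=` to `confusion_matrix` so the row/column order is known, then reusing the same list for the DataFrame index and columns; the `true:`/`pred:` prefixes are only an illustrative naming choice.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict
from sklearn.utils.multiclass import unique_labels

# Assumption: iris (3 classes) stands in for the 13-class data in the question.
iris = load_iris()
X = iris.data
y = iris.target_names[iris.target]  # string class names instead of 0/1/2

# Out-of-fold predictions: each sample is predicted by a model trained without it.
clf = LogisticRegression(max_iter=1000)
y_pred = cross_val_predict(clf, X, y, cv=5)

# Without labels=, confusion_matrix orders rows/columns by the sorted union of
# the labels seen in y_true and y_pred, which is what unique_labels returns.
default_order = unique_labels(y, y_pred)
print("default order:", list(default_order))

# Passing labels= explicitly pins down the order of rows and columns.
labels = list(iris.target_names)
cm = confusion_matrix(y, y_pred, labels=labels)

# Rows are true classes, columns are predicted classes.
df = pd.DataFrame(
    cm,
    index=[f"true:{name}" for name in labels],
    columns=[f"pred:{name}" for name in labels],
)
print(df)
```

    If you already computed a matrix without `labels=`, `unique_labels(y_true, y_pred)` gives the order scikit-learn used, so you can still attach the names afterwards, e.g. with `cm2df(cm, list(unique_labels(y_true, y_pred)))`.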
    
