Faster method of computing confusion matrix?

太阳男子 2021-01-07 07:03

I am computing my confusion matrix for image semantic segmentation as shown below, which is a pretty verbose approach:

def confusion_matrix(preds, labels, con         


        
2 Answers
  • 2021-01-07 07:33

    Thanks to Grigory Feldman for the answer! I had to change a few things to work with my implementation.

    For future readers, here is my final function. It sums up the percentages of each confusion matrix over a batch of inputs (to be used in training or testing loops):

    def confusion_matrix_2(y_true, y_pred, sample_sz, conf_m):
        # normalize() is a helper from my own code (not shown here)
        y_pred = normalize(y_pred, 0.9)
        # pixels belonging to each ground-truth class
        obj = y_true[y_true == 1]
        no_obj = y_true[y_true == 0]
        # number of classes, inferred from the largest label seen
        N = int(torch.max(torch.max(y_true), torch.max(y_pred))) + 1
        y_true = y_true.long()
        y_pred = y_pred.long()
        # encode each (true, pred) pair as one index, then count occurrences
        y = N * y_true + y_pred
        y = torch.bincount(y.flatten(), minlength=N * N)
        y = y.reshape(N, N).float()
        # add this sample's row-normalized counts, averaged over the batch
        conf_m[0, :] += y[0, :] / (len(no_obj) * sample_sz)
        conf_m[1, :] += y[1, :] / (len(obj) * sample_sz)
        return conf_m
    
    ...
    conf_m = torch.zeros((2, 2), dtype=torch.float)  # two classes (object or no-object)
    for _, data in enumerate(dataloader):
        for img,label in enumerate(data):
            ...
            out = Net(img)
            # pass the running matrix in so the function can accumulate into it
            conf_m = confusion_matrix_2(label, out, len(data), conf_m)
            ...
        ...
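
One possible variation on the function above, offered only as a sketch: accumulate raw integer counts over all batches and normalize the rows once at the end. This avoids dividing by zero when a batch contains no object pixels. The names `batch_confusion_counts` and `row_normalize` and the toy masks are made up for illustration, and the masks are assumed to be already thresholded to 0/1. Note that this pools all pixels together, so the result can differ from averaging per-sample percentages as the function above does.

```python
import torch

def batch_confusion_counts(y_true, y_pred, num_classes=2):
    """Raw (true, pred) pair counts for one batch, shape (num_classes, num_classes)."""
    idx = num_classes * y_true.long().flatten() + y_pred.long().flatten()
    # minlength guarantees num_classes**2 bins even if some pairs never occur
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def row_normalize(counts):
    """Divide each row by its total; rows with no samples stay all zero."""
    totals = counts.sum(dim=1, keepdim=True).clamp(min=1)
    return counts.float() / totals

# Two toy "batches" of binary masks (hypothetical stand-ins for labels and thresholded outputs)
batches = [
    (torch.tensor([[0, 0], [1, 1]]), torch.tensor([[0, 1], [1, 1]])),
    (torch.tensor([[0, 0], [0, 0]]), torch.tensor([[0, 0], [0, 1]])),  # no object pixels
]

counts = torch.zeros((2, 2), dtype=torch.long)
for labels, preds in batches:
    counts += batch_confusion_counts(labels, preds)

conf_m = row_normalize(counts)
print(counts)  # rows are the true class, columns the predicted class
print(conf_m)
```

Keeping integer counts until the end also makes it easy to derive other metrics (per-class IoU, for example) from the same matrix.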
    
  • 2021-01-07 07:35

    I use these two functions to compute the confusion matrix (as it is defined in sklearn):

    import torch

    # rewrite sklearn method to torch
    def confusion_matrix_1(y_true, y_pred):
        N = max(max(y_true), max(y_pred)) + 1
        y_true = torch.tensor(y_true, dtype=torch.long)
        y_pred = torch.tensor(y_pred, dtype=torch.long)
        # duplicate (true, pred) index pairs are summed when the sparse tensor is made dense
        return torch.sparse_coo_tensor(
            torch.stack([y_true, y_pred]),
            torch.ones_like(y_true, dtype=torch.long),
            torch.Size([N, N])).to_dense()
    
    # weird trick with bincount: encode each (true, pred) pair as the single index N * true + pred
    def confusion_matrix_2(y_true, y_pred):
        N = max(max(y_true), max(y_pred)) + 1
        y_true = torch.tensor(y_true, dtype=torch.long)
        y_pred = torch.tensor(y_pred, dtype=torch.long)
        y = N * y_true + y_pred
        y = torch.bincount(y)
        # pad with zeros when the highest pair indices never occur
        if len(y) < N * N:
            y = torch.cat((y, torch.zeros(N * N - len(y), dtype=torch.long)))  # torch.cat takes a sequence of tensors
        y = y.reshape(N, N)
        return y
    
    y_true = [2, 0, 2, 2, 0, 1]
    y_pred = [0, 0, 2, 2, 0, 2]
    
    confusion_matrix_1(y_true, y_pred)
    # tensor([[2, 0, 0],
    #         [0, 0, 1],
    #         [1, 0, 2]])
    
    

    The second function is faster when the number of classes is small.

    %%timeit
    confusion_matrix_1(y_true, y_pred)
    # 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    %%timeit
    confusion_matrix_2(y_true, y_pred)
    # 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
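
For anyone wondering why the bincount trick works, here is a plain-Python walk-through of the same arithmetic on the example lists from this answer. It is only an illustration and uses `collections.Counter` in place of `torch.bincount`: each (true, pred) pair is mapped to the flat index `N * true + pred`, the indices are counted, and the counts are reshaped to N x N.

```python
from collections import Counter

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
N = max(max(y_true), max(y_pred)) + 1  # 3 classes

# Each (true, pred) pair maps to a unique flat index in [0, N*N)
flat = [N * t + p for t, p in zip(y_true, y_pred)]

# Counting how often each flat index occurs is what bincount does
counts = Counter(flat)
bins = [counts.get(i, 0) for i in range(N * N)]

# Reshaping to N x N puts the true class on rows and the predicted class on columns
matrix = [bins[r * N:(r + 1) * N] for r in range(N)]

print(flat)    # [6, 0, 8, 8, 0, 5]
print(matrix)  # [[2, 0, 0], [0, 0, 1], [1, 0, 2]]
```

The result matches the tensor printed by `confusion_matrix_1` above. The padding step in `confusion_matrix_2` is needed because bincount only returns bins up to the largest index it actually sees.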
    