I am trying to calculate a class-wise confusion matrix and then compute per-class precision from it. Here is my function:
def calc_confusion(flat_labels, flat_logits, n_class
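For reference, here is a minimal NumPy sketch of the computation described above. It is not the original `calc_confusion`, and the function name `class_confusion_and_precision` is hypothetical. It makes three assumptions:

- `flat_labels` holds integer class ids with shape `(N,)`.
- `flat_logits` holds per-class scores with shape `(N, n_class)`, so the predicted class is the `argmax`.
- Rows of the matrix are true classes and columns are predicted classes.

```python
import numpy as np

def class_confusion_and_precision(flat_labels, flat_logits, n_class):
    """Hypothetical sketch: n_class x n_class confusion matrix plus per-class precision.

    Assumes flat_labels are integer class ids, shape (N,), and flat_logits are
    per-class scores, shape (N, n_class).
    """
    labels = np.asarray(flat_labels, dtype=np.int64)
    preds = np.argmax(np.asarray(flat_logits), axis=1)  # predicted class per sample

    # Encode each (true, predicted) pair as one index, count them, then reshape.
    # Rows index the true class, columns the predicted class.
    cm = np.bincount(labels * n_class + preds, minlength=n_class * n_class)
    cm = cm.reshape(n_class, n_class)

    # Precision for class c = TP / (TP + FP) = cm[c, c] / (sum of column c).
    tp = np.diag(cm)
    predicted_per_class = cm.sum(axis=0)
    # Classes that are never predicted get precision 0.0 instead of dividing by zero.
    precision = np.divide(tp, predicted_per_class,
                          out=np.zeros(n_class, dtype=float),
                          where=predicted_per_class > 0)
    return cm, precision

# Usage example with made-up scores for 3 classes.
labels = [0, 1, 2, 2, 1]
logits = [[2.0, 0.1, 0.3],
          [0.2, 1.5, 0.1],
          [0.1, 0.2, 3.0],
          [0.5, 2.0, 0.1],
          [0.1, 0.0, 0.9]]
cm, precision = class_confusion_and_precision(labels, logits, 3)
print(cm)
print(precision)
```

Precision divides by column sums because each column counts everything predicted as that class. If scikit-learn is available, `sklearn.metrics.confusion_matrix` uses the same rows-are-true, columns-are-predicted layout, and `precision_score(..., average=None)` returns per-class precision. Either can serve as a cross-check.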