Question
I am trying to calculate the entropy of each class of the dataset for an image classification task, in order to measure uncertainty in PyTorch, using the MC Dropout method and the solution proposed in this link:
Measuring uncertainty using MC Dropout on pytorch
First, I calculated the mean of each class per batch across the different forward passes (classes_mean_batch), then collected these means over the whole testloader (classes_mean), and finally applied some transformations to get total_mean, which I use to calculate the entropy, as shown in the code below:
import sys

import numpy as np
import torch
import torch.nn as nn

# enable_dropout, testloader and device are defined elsewhere
# (enable_dropout comes from the linked solution).


def mcdropout_test(batch_size, n_classes, model, T):
    # set non-dropout layers to eval mode
    model.eval()

    # set dropout layers to train mode
    enable_dropout(model)

    softmax = nn.Softmax(dim=1)
    classes_mean = []

    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        classes_mean_batch = []

        with torch.no_grad():
            output_list = []

            # getting outputs for T forward passes
            for _ in range(T):
                output = model(images)
                output = softmax(output)
                output_list.append(torch.unsqueeze(output, 0))

        # shape: (T, batch_size, n_classes)
        concat_output = torch.cat(output_list, 0)

        # getting mean of each class per batch across multiple MCD forward passes
        for i in range(n_classes):
            mean = torch.mean(concat_output[:, :, i])
            classes_mean_batch.append(mean)

        # getting mean of each class for the testloader
        classes_mean.append(torch.stack(classes_mean_batch))

    total_mean = []
    # shape: (n_batches, n_classes)
    concat_classes_mean = torch.stack(classes_mean)

    for i in range(n_classes):
        concat_classes = concat_classes_mean[:, i]
        total_mean.append(concat_classes)

    # shape: (n_classes, n_batches)
    total_mean = torch.stack(total_mean)
    total_mean = np.asarray(total_mean.cpu())

    epsilon = sys.float_info.min
    # Calculating entropy across multiple MCD forward passes
    entropy = (-np.sum(total_mean * np.log(total_mean + epsilon), axis=-1)).tolist()
    for i in range(n_classes):
        print(f'The uncertainty of class {i+1} is {entropy[i]:.4f}')
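To make the array shapes in the function above easier to follow, here is a minimal NumPy sketch of the same sequence of reductions. It is only an illustration under stated assumptions: the sizes `T`, `B`, `C` and `n_batches` are made up, and `fake_softmax_outputs` is a hypothetical stand-in for the `T` stochastic forward passes of the real model.

```python
import sys

import numpy as np

# Hypothetical sizes, chosen only for illustration.
T, B, C, n_batches = 5, 4, 3, 2
rng = np.random.default_rng(0)


def fake_softmax_outputs(T, B, C):
    """Stand-in for T stochastic forward passes: random rows that sum to 1."""
    logits = rng.normal(size=(T, B, C))
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)


classes_mean = []
for _ in range(n_batches):
    concat_output = fake_softmax_outputs(T, B, C)          # (T, B, C)
    # Mean over the passes AND over the samples in the batch:
    # one scalar per class, as in the per-class loop of the function above.
    classes_mean_batch = concat_output.mean(axis=(0, 1))   # (C,)
    classes_mean.append(classes_mean_batch)

concat_classes_mean = np.stack(classes_mean)   # (n_batches, C)
total_mean = concat_classes_mean.T             # (C, n_batches)

epsilon = sys.float_info.min
entropy = -np.sum(total_mean * np.log(total_mean + epsilon), axis=-1)  # (C,)

print(concat_classes_mean.shape, total_mean.shape, entropy.shape)
# -> (2, 3) (3, 2) (3,)

# Each row of total_mean holds one class's mean probability in every batch,
# so a row does not generally sum to 1.
print(total_mean.sum(axis=-1))
```

Under these assumptions, the final sum runs over batches for a fixed class, not over classes for a fixed sample, so the summed values do not form a probability distribution.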
Can anyone please correct or confirm the implementation I have used to calculate the entropy of each class?
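For comparison, one common definition of MC Dropout uncertainty is the predictive entropy of each sample, H(x) = -Σ_c p̄_c(x) log p̄_c(x), where p̄(x) is the softmax output averaged over the T passes. The sketch below assumes that definition and groups the per-sample entropies by true label to get one number per class. The function names `predictive_entropy` and `mean_entropy_per_class` are hypothetical, and this is not necessarily the quantity intended in the question.

```python
import numpy as np


def predictive_entropy(probs):
    """probs: (T, N, C) softmax outputs from T stochastic passes.

    Returns the entropy of the pass-averaged distribution, shape (N,).
    """
    mean_probs = probs.mean(axis=0)                 # (N, C)
    eps = np.finfo(mean_probs.dtype).tiny           # avoids log(0)
    return -np.sum(mean_probs * np.log(mean_probs + eps), axis=-1)


def mean_entropy_per_class(probs, labels, n_classes):
    """Average the per-sample entropy over the samples of each true class."""
    h = predictive_entropy(probs)
    return np.array([
        h[labels == c].mean() if np.any(labels == c) else np.nan
        for c in range(n_classes)
    ])


# Toy data: T=2 passes, N=2 samples, C=2 classes.
# Sample 0 is predicted with certainty (entropy close to 0);
# sample 1 is maximally uncertain (entropy log 2).
probs = np.array([
    [[1.0, 0.0], [0.5, 0.5]],
    [[1.0, 0.0], [0.5, 0.5]],
])
labels = np.array([0, 1])
result = mean_entropy_per_class(probs, labels, n_classes=2)
print(result)
```

In a real evaluation loop, `probs` would be the stacked softmax outputs for the whole test set and `labels` the matching ground-truth labels, both moved to NumPy first.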
Source: https://stackoverflow.com/questions/63755687/calculate-entropy-for-each-class-of-the-test-set-to-measure-uncertainty-on-pytor