I have a loss function that requires multiple internal passes:

    def my_loss_func(logits, sigma, labels, num_passes):
        total_loss = 0
        img_batch_size = lo
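The body of `my_loss_func` is cut off above. A common multi-pass loss with this signature perturbs the logits with Gaussian noise of standard deviation `sigma` on each pass and averages the softmax cross-entropy over the passes. That noise model is an assumption and is not confirmed by the question. Below is a minimal pure-Python sketch of that pattern. The name `mc_noisy_logits_loss`, the `seed` parameter, and the list-of-lists input layout are all hypothetical.

```python
import math
import random


def mc_noisy_logits_loss(logits, sigma, labels, num_passes, seed=None):
    """Hypothetical Monte Carlo loss (assumed noise model, not the asker's code).

    logits: list of per-example lists of class scores (batch x classes)
    sigma:  list of per-example noise standard deviations (batch)
    labels: list of integer class indices (batch)
    Returns the cross-entropy averaged over passes and examples.
    """
    rng = random.Random(seed)
    batch_size = len(logits)
    total_loss = 0.0
    for _ in range(num_passes):
        for row, s, y in zip(logits, sigma, labels):
            # Assumption: each logit is corrupted by Gaussian noise with std s.
            noisy = [z + s * rng.gauss(0.0, 1.0) for z in row]
            # Numerically stable log-softmax via the log-sum-exp trick.
            m = max(noisy)
            log_norm = m + math.log(sum(math.exp(v - m) for v in noisy))
            # Cross-entropy for the true class y: -log softmax(noisy)[y].
            total_loss += log_norm - noisy[y]
    # Average over both the passes and the examples in the batch.
    return total_loss / (num_passes * batch_size)
```

With `sigma` set to zero, every pass is identical and the result reduces to ordinary cross-entropy. For example, two equal logits give `log(2)`. In a tensor framework the Python loop over passes would usually be replaced by sampling noise of shape `(num_passes, batch, classes)` in one call and reducing over the first axis.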