Simple and short question. I have a network (U-Net) that performs image segmentation. I want it to output raw logits so that I can feed them into the cross-entropy loss (using PyTorch). Cu