I am doing multi-class segmentation using a UNet. My input to the model is HxWxC, and my output is HxWxC:
outputs = layers.Conv2D(n_classes, (1, 1), activ
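For reference, a `Conv2D(n_classes, (1, 1))` head applies the same linear map from the C feature channels to `n_classes` scores at every pixel, so its output's last dimension is `n_classes`. Below is a minimal NumPy sketch of that per-pixel projection. It assumes a softmax activation, since the activation argument above is cut off; softmax is a common choice for mutually exclusive classes. The helper names `conv1x1` and `softmax` are hypothetical and are not part of Keras.

```python
import numpy as np

def conv1x1(features, weights, bias):
    """Apply a 1x1 convolution: the same linear map at every pixel.

    features: (H, W, C) array; weights: (C, n_classes); bias: (n_classes,)
    returns: (H, W, n_classes) logits
    """
    # Contract the channel axis of features with the first axis of weights
    return np.tensordot(features, weights, axes=([2], [0])) + bias

def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Toy sizes chosen for illustration only
H, W, C, n_classes = 4, 5, 8, 3
rng = np.random.default_rng(0)
features = rng.standard_normal((H, W, C))
weights = rng.standard_normal((C, n_classes))
bias = np.zeros(n_classes)

# Assumed softmax activation: per-pixel class probabilities over the last axis
probs = softmax(conv1x1(features, weights, bias))
mask = probs.argmax(axis=-1)  # (H, W) integer class map
print(probs.shape, mask.shape)  # prints "(4, 5, 3) (4, 5)"
```

Taking `argmax` over the last axis turns the per-pixel probabilities into a single-channel label map, which is usually what gets compared against the ground-truth mask.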