I'm using TensorFlow for semantic segmentation. How can I tell TensorFlow to ignore a specific label when computing the pixelwise loss?
I've read in this post that f
According to the documentation, `tf.nn.softmax_cross_entropy_with_logits` must be called with a valid probability distribution in each row of `labels`, otherwise the computation will be incorrect. Using `tf.nn.sparse_softmax_cross_entropy_with_logits` (which might be more convenient in your case) with negative labels is no better: labels outside `[0, num_classes)` raise an error when the op runs on CPU and return NaN loss and gradient values on GPU. I wouldn't rely on either op to ignore some labels.
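To make those preconditions concrete, here is a small plain-Python sketch of the checks each op expects its labels to pass. `is_valid_dense_label` and `is_valid_sparse_label` are hypothetical helpers written for illustration; they are not TensorFlow APIs, they only restate the documented requirements.

```python
import math

def is_valid_dense_label(row, tol=1e-6):
    """Dense label row (hypothetical check): non-negative entries summing to 1."""
    return all(p >= 0 for p in row) and math.isclose(sum(row), 1.0, abs_tol=tol)

def is_valid_sparse_label(label, num_classes):
    """Sparse label (hypothetical check): an integer class index in [0, num_classes)."""
    return isinstance(label, int) and 0 <= label < num_classes

# One-hot and soft rows are valid distributions; an all-zero row is not.
print(is_valid_dense_label([0.0, 1.0, 0.0]))  # True
print(is_valid_dense_label([0.2, 0.3, 0.5]))  # True
print(is_valid_dense_label([0.0, 0.0, 0.0]))  # False
# A negative "ignore" index falls outside [0, num_classes).
print(is_valid_sparse_label(-1, 3))           # False
print(is_valid_sparse_label(2, 3))            # True
```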
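What I would do instead is replace the logit of the ignored class with "infinity" (a very large number) in those pixels whose ground-truth class is the ignored one. The softmax then assigns probability 1 to the correct class in those pixels, so they contribute nothing to the loss: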
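```python
import tensorflow as tf

ignore_label = ...  # index of the class to ignore
# Assumed: input_batch is the one-hot ground truth and predictions are the
# logits, both of shape [batch, height, width, num_classes].

# Make zeros everywhere except in the ignored label's channel
input_batch_ignored = tf.concat(
    [tf.zeros_like(input_batch[:, :, :, :ignore_label]),
     tf.expand_dims(input_batch[:, :, :, ignore_label], -1),
     tf.zeros_like(input_batch[:, :, :, ignore_label + 1:])],
    axis=-1)
# Make the corresponding logits "infinity" (a big enough number);
# tf.where is the TF >= 1.0 name of the old tf.select
predictions_fix = tf.where(input_batch_ignored > 0,
                           1e30 * tf.ones_like(predictions), predictions)
# Compute the per-pixel loss with the fixed logits
loss = tf.nn.softmax_cross_entropy_with_logits(labels=input_batch,
                                               logits=predictions_fix)
```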
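To see why the huge logit zeroes out a pixel's loss, here is a plain-Python sketch of the per-pixel cross-entropy, assuming the op computes the standard numerically stable log-softmax. The class index `IGNORE` and the logit values are made up for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax: x_i - logsumexp(x)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def cross_entropy(one_hot, logits):
    """Cross-entropy of one pixel: sum_i -y_i * log(softmax(x)_i)."""
    return sum(-y * lp for y, lp in zip(one_hot, log_softmax(logits)))

IGNORE = 2                   # hypothetical index of the ignored class
label = [0.0, 0.0, 1.0]      # this pixel belongs to the ignored class
logits = [2.0, 0.5, -1.0]    # made-up raw logits for the pixel

print(cross_entropy(label, logits))  # about 3.24

# Replace the ignored-class logit with a huge value, as done above
fixed = list(logits)
fixed[IGNORE] = 1e30
print(cross_entropy(label, fixed))   # 0.0
```

With the huge logit, `exp(x - m)` underflows to 0 for every other class, so the log-sum-exp equals the ignored logit itself and its log-probability is exactly 0.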
The only problem with this is that pixels of the ignored class are treated as if they were always predicted correctly (zero loss), so the mean loss of images containing a lot of those pixels will be artificially low. Depending on the case this may or may not be significant, but if you want to be really accurate, you have to weight the loss of each image according to its number of non-ignored pixels, instead of just taking the mean:
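```python
# Count relevant (non-ignored) pixels in each image
input_batch_relevant = 1 - input_batch[:, :, :, ignore_label]
relevant_count = tf.reduce_sum(input_batch_relevant, [1, 2])
# Compute relative weights (each image's share of the relevant pixels)
input_batch_weight = relevant_count / tf.reduce_sum(relevant_count)
# Per-image loss: mean over its relevant pixels (ignored pixels add 0)
image_loss = tf.reduce_sum(loss, [1, 2]) / tf.maximum(relevant_count, 1.0)
# Compute reduced loss according to weights
reduced_loss = tf.reduce_sum(image_loss * input_batch_weight)
```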
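The effect of the weighting can be checked in plain Python on a made-up batch of two tiny "images" (flattened to 4 pixels each), where ignored pixels already have loss 0 because of the huge-logit trick. The loss values and masks are illustrative only.

```python
# Made-up per-pixel losses; ignored pixels already have loss 0.
losses = [[0.5, 0.0, 0.0, 0.0],     # image 0: 3 of 4 pixels ignored
          [0.25, 0.5, 0.75, 0.5]]   # image 1: no pixels ignored
ignored = [[0, 1, 1, 1],
           [0, 0, 0, 0]]

def plain_mean(losses):
    """Mean over every pixel, ignored ones included (biased low)."""
    flat = [l for img in losses for l in img]
    return sum(flat) / len(flat)

def weighted_mean(losses, ignored):
    """Per-image mean over relevant pixels, weighted by each image's share of them."""
    counts = [len(mask) - sum(mask) for mask in ignored]
    total = sum(counts)
    reduced = 0.0
    for img, n in zip(losses, counts):
        if n == 0:
            continue  # fully ignored image gets weight 0
        reduced += (sum(img) / n) * (n / total)
    return reduced

print(plain_mean(losses))              # 0.3125
print(weighted_mean(losses, ignored))  # 0.5, the mean over the 5 relevant pixels
```

The weights cancel the per-image denominators, so the result is simply the total loss divided by the number of relevant pixels in the batch.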