Tensorflow: How to ignore specific labels during semantic segmentation?

执笔经年 2021-02-15 14:49

I'm using tensorflow for semantic segmentation. How can I tell tensorflow to ignore a specific label when computing the pixelwise loss?

I've read in this post that f…

2 answers
  •  灰色年华
    2021-02-15 15:02

    According to the documentation, tf.nn.softmax_cross_entropy_with_logits expects the labels of every pixel to be a valid probability distribution; otherwise the computation will be incorrect. tf.nn.sparse_softmax_cross_entropy_with_logits (which might be more convenient in your case) requires every label to be a class index in [0, num_classes): a negative or out-of-range label raises an error when run on CPU and produces NaN for the corresponding loss and gradient on GPU. So I wouldn't rely on either of them to ignore some labels.
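    For reference, a common way to use a sparse loss safely anyway (a different technique from the logit trick proposed in this answer) is to never pass the invalid label to the loss at all: map ignored pixels to any valid class index, compute the per-pixel loss, zero it out with a mask, and average over the relevant pixels only. The sketch below uses NumPy in place of TensorFlow so the arithmetic is easy to check; the function name sparse_ce_ignore and the convention ignore_label=255 are hypothetical choices for illustration. In TensorFlow the same steps would map onto tf.where, tf.nn.sparse_softmax_cross_entropy_with_logits and tf.reduce_sum.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (class) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sparse_ce_ignore(logits, labels, ignore_label=255):
    """Mean cross-entropy over the pixels whose label is not ignore_label.

    logits: float array [batch, height, width, num_classes]
    labels: int array [batch, height, width]; ignore_label (a hypothetical
            convention) marks the pixels to skip.
    """
    mask = labels != ignore_label
    # Replace ignored labels with a valid index (0) so the lookup never fails
    safe_labels = np.where(mask, labels, 0)
    log_probs = log_softmax(logits)
    # Log-probability of the (safe) true class at every pixel
    picked = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Ignored pixels are multiplied by False (0), so they contribute nothing
    per_pixel = -picked * mask
    # Average over relevant pixels only; guard against an all-ignored batch
    return per_pixel.sum() / max(mask.sum(), 1)


# Uniform predictions over 3 classes; one of the four pixels is ignored
loss = sparse_ce_ignore(np.zeros((1, 2, 2, 3)), np.array([[[0, 1], [255, 2]]]))
# loss equals log(3) ≈ 1.0986: three relevant pixels, each with loss log(3)
```

    Dividing by the number of relevant pixels rather than by all pixels avoids the dilution problem discussed further down in this answer.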

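    What I would do instead is replace the logit of the ignored class with "infinity" (in practice, a very large finite number) in those pixels where the correct class is the ignored one. The softmax then gives that class a probability of 1, so the cross-entropy of those pixels is -log(1) = 0 and they contribute nothing to the loss: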

    import tensorflow as tf

    # input_batch: one-hot ground truth, shape [batch, height, width, num_classes]
    # predictions: logits with the same shape
    ignore_label = ...
    # Make zeros everywhere except for the ignored label
    input_batch_ignored = tf.concat(
        [tf.zeros_like(input_batch[:, :, :, :ignore_label]),
         tf.expand_dims(input_batch[:, :, :, ignore_label], -1),
         tf.zeros_like(input_batch[:, :, :, ignore_label + 1:])],
        axis=-1)
    # Make corresponding logits "infinity" (a big enough finite number)
    predictions_fix = tf.where(input_batch_ignored > 0,
                               1e30 * tf.ones_like(predictions), predictions)
    # Compute the per-pixel loss with the fixed logits: shape [batch, height, width]
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=input_batch,
                                                   logits=predictions_fix)
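    To check numerically why the large logit removes the ignored pixels from the loss, here is a small sketch with NumPy standing in for TensorFlow (the helper names softmax_ce and fixed_logits_loss and the toy shapes are hypothetical). After the fix, the ignored class holds essentially all the probability mass at those pixels, so their cross-entropy is exactly 0, while every other pixel keeps its original loss.

```python
import numpy as np


def softmax_ce(labels, logits):
    """Per-pixel softmax cross-entropy over the last axis (one-hot labels)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels * log_probs).sum(axis=-1)


def fixed_logits_loss(onehot, logits, ignore_label):
    """Per-pixel loss after forcing the ignored class's logit to 1e30 wherever
    the ground truth is that class (same idea as the TensorFlow snippet)."""
    # 1 only in the ignored channel, and only where the ground truth is ignored
    ignored = np.zeros_like(onehot)
    ignored[..., ignore_label] = onehot[..., ignore_label]
    fixed = np.where(ignored > 0, 1e30, logits)
    return softmax_ce(onehot, fixed)


# Toy batch: 1 image of 2x2 pixels, 3 classes, class 2 is the ignored label
onehot = np.eye(3)[np.array([[[0, 2], [1, 2]]])]   # shape [1, 2, 2, 3]
loss = fixed_logits_loss(onehot, np.zeros((1, 2, 2, 3)), ignore_label=2)
# Pixels labelled 2 now have zero loss; the others keep log(3) ≈ 1.0986
```

    The value 1e30 is still representable in float32, so no inf appears in the computation.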
    

    The only problem with this is that the pixels of the ignored class are treated as always predicted correctly: they contribute zero loss but still count when you average over all pixels, so the loss of images containing a lot of those pixels will be artificially smaller. Depending on the case this may or may not be significant, but if you want to be really accurate, you should average each image's loss over its non-ignored pixels only and weight each image according to its number of non-ignored pixels, instead of just taking the mean:

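    # Count relevant (non-ignored) pixels in each image: shape [batch]
    input_batch_relevant = 1.0 - input_batch[:, :, :, ignore_label]
    relevant_count = tf.reduce_sum(input_batch_relevant, axis=[1, 2])
    # Compute relative weights of the images
    input_batch_weight = relevant_count / tf.reduce_sum(relevant_count)
    # Mean loss of each image over its relevant pixels only
    image_loss = (tf.reduce_sum(loss, axis=[1, 2])
                  / tf.maximum(relevant_count, 1.0))
    # Compute reduced loss according to weights
    reduced_loss = tf.reduce_sum(image_loss * input_batch_weight)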
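    The effect of the weighting can be checked on a toy batch, again with NumPy in place of TensorFlow (weighted_batch_loss and the numbers below are hypothetical). Weighting each image's mean loss over its relevant pixels by that image's share of relevant pixels works out to the loss averaged over all relevant pixels in the batch, whereas a plain mean over every pixel is pulled down by the zero-loss ignored pixels.

```python
import numpy as np


def weighted_batch_loss(pixel_loss, relevant):
    """pixel_loss: [batch, H, W] per-pixel loss (0 at ignored pixels).
    relevant:   [batch, H, W] array, 1.0 where the pixel is not ignored."""
    counts = relevant.sum(axis=(1, 2))                  # relevant pixels per image
    weights = counts / counts.sum()                     # relative image weights
    # Mean loss of each image over its relevant pixels (guard against 0 pixels)
    per_image = (pixel_loss * relevant).sum(axis=(1, 2)) / np.maximum(counts, 1)
    return (per_image * weights).sum()


pixel_loss = np.array([[[1.0, 1.0], [1.0, 1.0]],       # image 0: nothing ignored
                       [[2.0, 0.0], [0.0, 0.0]]])      # image 1: 3 pixels ignored
relevant = np.array([[[1.0, 1.0], [1.0, 1.0]],
                     [[1.0, 0.0], [0.0, 0.0]]])

naive = pixel_loss.mean()                               # 6 / 8 = 0.75
weighted = weighted_batch_loss(pixel_loss, relevant)    # 6 / 5 = 1.2
```

    Here the naive mean underestimates the loss because three of the eight pixels are ignored but still counted in the denominator.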
    
