I'm trying to apply a mask (binary, only one channel) to an RGB image (3 channels, normalized to [0, 1]). My current solution is that I split the RGB image into its chan
The `tf.mul()` operator (renamed `tf.multiply()` in TensorFlow 1.0) supports numpy-style broadcasting, which would allow you to simplify and optimize the code slightly. Let's say that `zero_one_mask` is an m x n tensor, and `output_img` is a b x m x n x 3 tensor (where b is the batch size; I'm inferring this from the fact that you split `output_img` on dimension 3)*. You can use `tf.expand_dims()` to make `zero_one_mask` broadcastable across the channels by reshaping it to be an m x n x 1 tensor:

    with tf.variable_scope('apply_mask') as scope:
        # Output mask is in range [-1, 1], bring to range [0, 1] first
        # NOTE: Assumes `output_mask` is a 2-D `m x n` tensor.
        zero_one_mask = tf.expand_dims((output_mask + 1) / 2, 2)
        # Apply mask to all channels.
        # NOTE: Assumes `output_img` is a 4-D `b x m x n x c` tensor.
        output_img = tf.mul(output_img, zero_one_mask)

(* This would work equally if `output_img` were a 4-D b x m x n x c tensor (for any number of channels c) or a 3-D m x n x c tensor, due to the way broadcasting works.)
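
To see why the extra axis matters, here is a minimal pure-Python sketch of the numpy-style broadcasting rule applied to shapes only. The helper `broadcast_shape` and the sizes `b = 8`, `m = 4`, `n = 5` are hypothetical and chosen just for illustration; this is not TensorFlow's implementation, only the shape rule as I understand it (align shapes from the trailing axis, and treat a missing or size-1 axis as stretchable).

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the shape produced by numpy-style broadcasting, or raise ValueError.

    Hypothetical helper for illustration; it only reasons about shapes.
    """
    result = []
    # Compare dimensions starting from the trailing (rightmost) axis.
    # Missing leading axes are treated as size 1.
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for dim_a, dim_b in pairs:
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(f"incompatible shapes: {shape_a} vs {shape_b}")
    return tuple(reversed(result))


# Assumed sizes: batch b, height m, width n.
b, m, n = 8, 4, 5

# Mask after tf.expand_dims(..., 2) has shape m x n x 1.
print(broadcast_shape((b, m, n, 3), (m, n, 1)))  # 4-D image batch
print(broadcast_shape((m, n, 3), (m, n, 1)))     # single 3-D image

# Without the extra axis, the m x n mask lines up against the
# width and channel axes instead of height and width.
try:
    broadcast_shape((b, m, n, 3), (m, n))
except ValueError as err:
    print("fails:", err)
```

The first two calls succeed because the mask's trailing size-1 axis stretches across the channels, which matches the footnote's claim that both the 4-D and 3-D layouts work. The last call fails because a bare m x n mask is compared against the n x 3 trailing axes of the image.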