I want to use `tf.cond(pred, fn1, fn2, name=None)` for conditional branching. Let's say I have two tensors: `x` and `y`. Each tensor is a batch of 0/1 values, and I want to use th… [the question is cut off here]
`tf.where` sounds like what you want: a vectorized selection between Tensors.
`tf.cond` is a control flow modifier: it determines which ops are executed, and so it's difficult to think of useful batch semantics.
We can also put together a mixture of these operations: an operation which slices based on a condition and passes those slices to two branches.
import tensorflow as tf
from tensorflow.python.util import nest
def slicing_where(condition, full_input, true_branch, false_branch):
    """Split `full_input` between `true_branch` and `false_branch` on `condition`.

    The entries of `full_input` selected by `condition` are gathered and passed
    to `true_branch`, the remaining entries are passed to `false_branch`, and
    the two sets of outputs are scattered back to the batch positions they
    came from.

    Args:
      condition: A boolean Tensor with shape [B_1, ..., B_N].
      full_input: A Tensor or nested structure of Tensors of any dtype, each
        with shape [B_1, ..., B_N, ...], to be split between `true_branch`
        and `false_branch` based on `condition`.
      true_branch: A function taking a single argument with the same nested
        structure as `full_input`. Each Tensor it receives has shape
        [K, ...]: the batch dimensions are collapsed into one leading
        dimension holding the K slices of `full_input` that correspond to
        the True entries of `condition`. Returns a Tensor or nested
        structure of Tensors, each keeping that leading dimension of size K.
      false_branch: Like `true_branch`, but receives the slices corresponding
        to the False entries of `condition`. Must return the same nested
        structure as `true_branch`, with matching dtypes and trailing
        (non-batch) dimensions.

    Returns:
      Interleaved outputs from `true_branch` and `false_branch`, with the
      same nested structure as the branch outputs and each Tensor having
      shape [B_1, ..., B_N, ...].
    """
    full_input_flat = nest.flatten(full_input)
    # Single-argument tf.where yields one row of N coordinates per True entry.
    true_indices = tf.where(condition)
    false_indices = tf.where(tf.logical_not(condition))

    def gather_branch_inputs(indices):
        # Select the slices at `indices` from every input Tensor, then restore
        # the caller's nesting.
        return nest.pack_sequence_as(
            structure=full_input,
            flat_sequence=[
                tf.gather_nd(params=input_tensor, indices=indices)
                for input_tensor in full_input_flat])

    true_branch_inputs = gather_branch_inputs(true_indices)
    false_branch_inputs = gather_branch_inputs(false_indices)
    true_outputs = true_branch(true_branch_inputs)
    false_outputs = false_branch(false_branch_inputs)
    nest.assert_same_structure(true_outputs, false_outputs)

    # These do not depend on the individual output, so build them once.
    # tf.where returns int64 coordinates; cast so they match the int32 shape
    # produced by tf.shape.
    batch_shape = tf.shape(condition)
    all_indices = tf.cast(
        tf.concat([true_indices, false_indices], 0), tf.int32)

    def scatter_outputs(true_output, false_output):
        # Branch outputs carry exactly one leading batch dimension (the number
        # of selected entries), so everything after axis 0 is the per-entry
        # shape.
        scattered_shape = tf.concat(
            [batch_shape, tf.shape(true_output)[1:]], 0)
        # The True and False coordinates are disjoint and together cover the
        # whole batch, so one scatter places every slice. This avoids summing
        # two partially filled tensors, which only works for dtypes that
        # support addition.
        return tf.scatter_nd(
            indices=all_indices,
            updates=tf.concat([true_output, false_output], 0),
            shape=scattered_shape)

    return nest.pack_sequence_as(
        structure=true_outputs,
        flat_sequence=[
            scatter_outputs(true_single_output, false_single_output)
            for true_single_output, false_single_output
            in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
Some examples:
# Shared building blocks for both examples.
int_range = tf.range(10)
float_range = tf.range(10, dtype=tf.float32)

# Vector example: even positions get +0.2, odd positions get +0.1.
is_even = tf.equal(int_range % 2, 0)
vector_test = slicing_where(
    condition=is_even,
    full_input=float_range,
    true_branch=lambda selected: 0.2 + selected,
    false_branch=lambda selected: 0.1 + selected)

# Matrix example: a 10x10 multiplication table whose rows at indices
# divisible by 3 are negated, while all other rows get +0.1.
cross_range = tf.expand_dims(float_range, 1) * tf.expand_dims(float_range, 0)
is_multiple_of_three = tf.equal(int_range % 3, 0)
matrix_test = slicing_where(
    condition=is_multiple_of_three,
    full_input=cross_range,
    true_branch=lambda rows: -rows,
    false_branch=lambda rows: rows + 0.1)

with tf.Session() as session:
    for example in (vector_test, matrix_test):
        print(session.run(example))
Prints:
[ 0.2 1.10000002 2.20000005 3.0999999 4.19999981 5.0999999
6.19999981 7.0999999 8.19999981 9.10000038]
[[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[ 0.1 1.10000002 2.0999999 3.0999999 4.0999999
5.0999999 6.0999999 7.0999999 8.10000038 9.10000038]
[ 0.1 2.0999999 4.0999999 6.0999999 8.10000038
10.10000038 12.10000038 14.10000038 16.10000038 18.10000038]
[ 0. -3. -6. -9. -12. -15.
-18. -21. -24. -27. ]
[ 0.1 4.0999999 8.10000038 12.10000038 16.10000038
20.10000038 24.10000038 28.10000038 32.09999847 36.09999847]
[ 0.1 5.0999999 10.10000038 15.10000038 20.10000038
25.10000038 30.10000038 35.09999847 40.09999847 45.09999847]
[ 0. -6. -12. -18. -24. -30.
-36. -42. -48. -54. ]
[ 0.1 7.0999999 14.10000038 21.10000038 28.10000038
35.09999847 42.09999847 49.09999847 56.09999847 63.09999847]
[ 0.1 8.10000038 16.10000038 24.10000038 32.09999847
40.09999847 48.09999847 56.09999847 64.09999847 72.09999847]
[ 0. -9. -18. -27. -36. -45.
-54. -63. -72. -81. ]]