How to use tf.cond for batch processing

青春惊慌失措 2021-01-04 11:46

I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values and I want to use th

1 answer
  • 2021-01-04 12:36

    tf.where sounds like what you want: a vectorized selection between Tensors.
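As a minimal sketch of that vectorized selection, assuming TensorFlow 2.x with eager execution: the names `flags`, `a`, and `b` below are hypothetical stand-ins for a batch of 0/1 values and two candidate tensors, not taken from the question.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# Hypothetical batch: per-element 0/1 flags and two candidate values.
flags = tf.constant([1, 0, 1, 0])
a = tf.constant([10.0, 20.0, 30.0, 40.0])
b = tf.constant([-1.0, -2.0, -3.0, -4.0])

# Both candidates exist for every element; tf.where picks a[i] where the
# condition is True and b[i] where it is False.
selected = tf.where(tf.equal(flags, 1), a, b)
print(selected.numpy())
```

Note that both `a` and `b` are computed in full before the selection, so this only chooses between values; it does not skip any computation.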

    tf.cond is a control flow op: a single scalar predicate determines which branch's ops are executed, so it has no natural per-element (batch) semantics.
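For contrast, here is a small sketch of tf.cond with a scalar predicate, again assuming TensorFlow 2.x with eager execution. The helper `pick_branch` is a hypothetical name used only for illustration.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

def pick_branch(pred, x):
    # pred is a scalar boolean tensor; only the selected branch function
    # is called, and it is applied to the whole of x.
    return tf.cond(pred, lambda: x + 1.0, lambda: x - 1.0)

x = tf.constant([1.0, 2.0, 3.0])
out_true = pick_branch(tf.constant(True), x)
out_false = pick_branch(tf.constant(False), x)
print(out_true.numpy(), out_false.numpy())
```

The predicate chooses one branch for the entire tensor, which is why it does not map onto a batch of independent 0/1 decisions.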

    We can also combine the two ideas: an operation that slices the input based on a per-element condition, passes each slice to its own branch, and interleaves the results.

    import tensorflow as tf
    from tensorflow.python.util import nest
    
    def slicing_where(condition, full_input, true_branch, false_branch):
      """Split `full_input` between `true_branch` and `false_branch` on `condition`.
    
      Args:
        condition: A boolean Tensor with shape [B_1, ..., B_N].
        full_input: A Tensor or nested tuple of Tensors of any dtype, each with
          shape [B_1, ..., B_N, ...], to be split between `true_branch` and
          `false_branch` based on `condition`.
        true_branch: A function taking a single argument, that argument having the
          same structure and number of batch dimensions as `full_input`. Receives
          slices of `full_input` corresponding to the True entries of
          `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
          dimensions matching its inputs.
        false_branch: Like `true_branch`, but receives inputs corresponding to the
          false elements of `condition`. Returns a Tensor or nested tuple of Tensors
          (with the same structure as the return value of `true_branch`), but with
          batch dimensions matching its inputs.
      Returns:
        Interleaved outputs from `true_branch` and `false_branch`, each Tensor
        having shape [B_1, ..., B_N, ...].
      """
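      full_input_flat = nest.flatten(full_input)
      # Coordinates of the True and False entries of `condition`; each result
      # has shape [num_entries, N].
      true_indices = tf.where(condition)
      false_indices = tf.where(tf.logical_not(condition))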
      true_branch_inputs = nest.pack_sequence_as(
          structure=full_input,
          flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                         for input_tensor in full_input_flat])
      false_branch_inputs = nest.pack_sequence_as(
          structure=full_input,
          flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                         for input_tensor in full_input_flat])
      true_outputs = true_branch(true_branch_inputs)
      false_outputs = false_branch(false_branch_inputs)
      nest.assert_same_structure(true_outputs, false_outputs)
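      def scatter_outputs(true_output, false_output):
        batch_shape = tf.shape(condition)
        # Branch outputs have one leading dimension (one row per selected entry),
        # so the per-entry output shape starts at index 1. tf.rank(batch_shape)
        # is always 1 because batch_shape is a vector.
        scattered_shape = tf.concat(
            [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
            0)
        true_scatter = tf.scatter_nd(
            indices=tf.cast(true_indices, tf.int32),
            updates=true_output,
            shape=scattered_shape)
        false_scatter = tf.scatter_nd(
            indices=tf.cast(false_indices, tf.int32),
            updates=false_output,
            shape=scattered_shape)
        # The two scatters write to disjoint positions and leave zeros elsewhere,
        # so adding them interleaves the results (numeric dtypes only).
        return true_scatter + false_scatter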
      result = nest.pack_sequence_as(
          structure=true_outputs,
          flat_sequence=[
              scatter_outputs(true_single_output, false_single_output)
              for true_single_output, false_single_output
              in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
      return result
    

    Some examples:

    vector_test = slicing_where(
        condition=tf.equal(tf.range(10) % 2, 0),
        full_input=tf.range(10, dtype=tf.float32),
        true_branch=lambda x: 0.2 + x,
        false_branch=lambda x: 0.1 + x)
    
    cross_range = (tf.range(10, dtype=tf.float32)[:, None]
                   * tf.range(10, dtype=tf.float32)[None, :])
    matrix_test = slicing_where(
        condition=tf.equal(tf.range(10) % 3, 0),
        full_input=cross_range,
        true_branch=lambda x: -x,
        false_branch=lambda x: x + 0.1)
    
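    # Written for TF 1.x graph mode. Under TF 2.x eager execution there is no
    # tf.Session; printing vector_test.numpy() and matrix_test.numpy() should
    # work instead.
    with tf.Session():
      print(vector_test.eval())
      print(matrix_test.eval())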
    

    Prints:

    [ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
      6.19999981  7.0999999   8.19999981  9.10000038]
    [[  0.           0.           0.           0.           0.           0.
        0.           0.           0.           0.        ]
     [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
        5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
     [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
       10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
     [  0.          -3.          -6.          -9.         -12.         -15.
      -18.         -21.         -24.         -27.        ]
     [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
       20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
     [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
       25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
     [  0.          -6.         -12.         -18.         -24.         -30.
      -36.         -42.         -48.         -54.        ]
     [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
       35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
     [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
       40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
     [  0.          -9.         -18.         -27.         -36.         -45.
      -54.         -63.         -72.         -81.        ]]
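For the common case of a single batch dimension, a shorter variant is possible. The sketch below is not the function above: it swaps in tf.boolean_mask and tf.dynamic_stitch, assumes TensorFlow 2.x with eager execution, and only handles a 1-D `condition`. The name `split_apply_merge` is hypothetical.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

def split_apply_merge(condition, values, true_fn, false_fn):
    """1-D sketch: route rows of `values` to two branches, then re-interleave."""
    positions = tf.range(tf.shape(condition)[0])
    # Row positions that go to each branch.
    true_idx = tf.boolean_mask(positions, condition)
    false_idx = tf.boolean_mask(positions, tf.logical_not(condition))
    # Each branch only sees (and only computes on) its own rows.
    true_out = true_fn(tf.gather(values, true_idx))
    false_out = false_fn(tf.gather(values, false_idx))
    # dynamic_stitch places each branch's rows back at their original positions.
    return tf.dynamic_stitch([true_idx, false_idx], [true_out, false_out])

merged = split_apply_merge(
    condition=tf.equal(tf.range(10) % 2, 0),
    values=tf.range(10, dtype=tf.float32),
    true_fn=lambda x: 0.2 + x,
    false_fn=lambda x: 0.1 + x)
print(merged.numpy())
```

With the same inputs as the vector example above, this should reproduce the first printed result up to float32 rounding. Unlike the scatter-and-add approach, it does not rely on adding zero-filled tensors, but it also does not handle nested inputs or multiple batch dimensions.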
    