How to make an if statement using a boolean Tensor

小鲜肉 2021-01-11 12:38

How do I make an if statement using a boolean tensor? To be more precise, I'm trying to compare a tensor of size 1 to a constant, checking to see if the value in the tensor …

2 Answers
  • 2021-01-11 13:10

    TL;DR: In TensorFlow 1.x graph mode, you need to use Session.run() to get a Python boolean, but there are other ways to achieve the same result that might be more efficient.

    It looks like you've already figured out how to get a boolean tensor from your value, but for the benefit of other readers, it would look something like this:

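    import tensorflow as tf  # TensorFlow 1.x graph-mode API
    
    computed_val = ...  # Any float32 tensor you have already computed.
    constant_val = tf.constant(37.0)
    pred = tf.less(computed_val, constant_val)  # N.B. Types of the two args must match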
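
    The N.B. about matching types matters in practice: tf.less() rejects an int32 tensor compared against a float32 constant, so cast one side first. A minimal sketch, assuming TensorFlow 2.x with eager execution (the default) and a hypothetical size-1 int32 tensor like the one described in the question:

```python
import tensorflow as tf

# Hypothetical size-1 input tensor, as described in the question.
computed_val = tf.constant([12], dtype=tf.int32)
constant_val = tf.constant(37.0)  # float32

# tf.less needs both arguments to have the same dtype, so cast the int32 tensor.
pred = tf.less(tf.cast(computed_val, constant_val.dtype), constant_val)

# pred is a tf.bool tensor with the same shape as computed_val, i.e. [1].
print(pred.dtype, pred.shape.as_list(), pred.numpy())
```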
    

    The next part is how to use it as a conditional. The simplest thing to do is to use a Python if statement, but to do that you must evaluate the tensor pred using Session.run():

    sess = tf.Session()
    
    if sess.run(pred):
      pass  # Do something.
    else:
      pass  # Do something else.
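
    In TensorFlow 2.x, eager execution is enabled by default and there is no Session, so a scalar boolean tensor can be used directly in a Python if statement. A sketch, assuming TF 2.x eager mode; the function name classify is hypothetical:

```python
import tensorflow as tf

def classify(computed_val):
    """Return a label depending on whether computed_val < 37 (eager mode only)."""
    constant_val = tf.constant(37.0)
    pred = tf.less(computed_val, constant_val)  # scalar tf.bool tensor
    # An eager scalar boolean tensor converts to a Python bool,
    # so no Session.run() is needed here.
    if pred:
        return "less than 37"
    else:
        return "not less than 37"

print(classify(tf.constant(12.0)))
print(classify(tf.constant(40.0)))
```

    Inside a tf.function, AutoGraph rewrites an if statement whose condition is a tensor into tf.cond(), so both branches are built into the graph.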
    

    One caveat about using a Python if statement is that you have to evaluate the whole expression up to pred, which makes it tricky to reuse intermediate values that have already been computed. I'd like to draw your attention to two other ways you can compute conditional expressions using TensorFlow, which don't require you to evaluate the predicate and get a Python value back.

    The first way uses the tf.select() op (renamed tf.where() in TensorFlow 1.0) to conditionally pass through values from two tensors passed as arguments:

    import tensorflow as tf  # TensorFlow 1.x graph-mode API
    
    pred = tf.placeholder(tf.bool)  # Can be any computed boolean expression.
    val_if_true = tf.constant(28.0)
    val_if_false = tf.constant(12.0)
    result = tf.where(pred, val_if_true, val_if_false)  # tf.select() before TensorFlow 1.0
    
    sess = tf.Session()
    sess.run(result, feed_dict={pred: True})   # ==> 28.0
    sess.run(result, feed_dict={pred: False})  # ==> 12.0
    

    The tf.select() op works element-wise on all of its arguments, which allows you to combine values from the two input tensors. See its documentation for more details. The drawback of tf.select() is that it evaluates both val_if_true and val_if_false before computing the result, which might be expensive if they are complicated expressions.
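
    Because the selection is element-wise, the predicate can be a whole boolean tensor rather than a single value. A small sketch, assuming TensorFlow 2.x eager mode, where tf.select() is available as tf.where():

```python
import tensorflow as tf

pred = tf.constant([True, False, True])
val_if_true = tf.constant([1.0, 2.0, 3.0])
val_if_false = tf.constant([10.0, 20.0, 30.0])

# Each output element comes from val_if_true where pred is True and from
# val_if_false where pred is False; both inputs are fully computed first.
result = tf.where(pred, val_if_true, val_if_false)
print(result.numpy().tolist())  # [1.0, 20.0, 3.0]
```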

    The second way uses the tf.cond() op, which conditionally evaluates one of two expressions. This is particularly useful if the expressions are expensive, and it is essential if they have side effects. The basic pattern is to specify two Python functions (or lambda expressions) that build subgraphs that will execute on the true or false branches:

    import tensorflow as tf  # TensorFlow 1.x graph-mode API
    
    # Define some large matrices (replace ... with your own tensors).
    a = ...
    b = ...
    c = ...
    
    pred = tf.placeholder(tf.bool)
    
    def if_true():
      return tf.matmul(a, b)
    
    def if_false():
      return tf.matmul(b, c)
    
    result = tf.cond(pred, if_true, if_false)
    
    sess = tf.Session()
    sess.run(result, feed_dict={pred: True})   # ==> executes only (a x b)
    sess.run(result, feed_dict={pred: False})  # ==> executes only (b x c)
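
    In TensorFlow 2.x eager mode, tf.cond() keeps the same contract: only the branch function selected by pred is called. The sketch below assumes TF 2.x and uses small stand-in matrices for a, b and c; the executed list is a hypothetical helper that records which branch actually ran:

```python
import tensorflow as tf

# Small stand-ins for the "large matrices" a, b and c.
a = tf.ones([2, 2])
b = tf.fill([2, 2], 2.0)
c = tf.fill([2, 2], 3.0)

executed = []  # records which branch function was actually called

def if_true():
    executed.append("a x b")
    return tf.matmul(a, b)

def if_false():
    executed.append("b x c")
    return tf.matmul(b, c)

result = tf.cond(tf.constant(False), if_true, if_false)
print(executed)  # only the false branch ran
print(result.numpy())
```

    Inside a tf.function (or a TF 1.x graph), both if_true and if_false are called once while the graph is built, but only the selected branch executes each time the graph runs.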
    
  • 2021-01-11 13:20

    Use tf.reshape(t, []) to turn the size-1 tensor into a scalar, then use that scalar's value in your if statement (in graph mode, evaluate it with Session.run() first).
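
    A minimal sketch of that approach, assuming TensorFlow 2.x eager mode and a hypothetical size-1 boolean tensor t:

```python
import tensorflow as tf

t = tf.constant([True])      # hypothetical size-1 boolean tensor
scalar = tf.reshape(t, [])   # reshape to a rank-0 (scalar) tensor

# In eager mode a scalar boolean tensor can drive a Python if directly.
if scalar:
    message = "value is True"
else:
    message = "value is False"
print(scalar.shape.as_list(), message)
```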
