How to filter tensors from a queue based on a predicate in TensorFlow?

南方客 2021-01-03 05:12

How can I filter data stored in a queue using a predicate function? For example, let's say we have a queue that stores tensors of features and labels and we just need those

1 answer
  • 2021-01-03 05:53

    The most straightforward way to do this is to dequeue a batch, run it through the predicate, use tf.where to get the indices of the elements that match, use tf.gather to collect those elements, and enqueue the result into a second queue. If you want that to happen automatically, you can start a queue runner on the second queue; the easiest way to do that is to use tf.train.batch.

    Example:

    # Written against the TensorFlow 1.x graph API (tf.Session, tf.FIFOQueue).
    import numpy as np
    import tensorflow as tf
    
    a = tf.constant(np.array([5, 1, 9, 4, 7, 0], dtype=np.int32))
    
    # First queue: holds the unfiltered scalars.
    q = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
    enqueue = q.enqueue_many([a])
    dequeue = q.dequeue_many(6)
    
    # Predicate: keep the values that are less than 5.
    predmatch = tf.less(dequeue, [5])
    # tf.where returns an [N, 1] tensor of indices; flatten it to a vector.
    selected_items = tf.reshape(tf.where(predmatch), [-1])
    found = tf.gather(dequeue, selected_items)
    
    # Second queue: holds only the values that passed the predicate.
    secondqueue = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
    enqueue2 = secondqueue.enqueue_many([found])
    dequeue2 = secondqueue.dequeue_many(3) # XXX, hardcoded
    
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(enqueue)  # Fill the first queue
      sess.run(enqueue2) # Filter, push into queue 2
      print(sess.run(dequeue2)) # Pop items off of queue2
    

    The predicate produces a boolean vector; tf.where returns the indices of the true values (reshaped here into a flat vector), and tf.gather collects the items at those indices from the original tensor.
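
    To make the three steps concrete, here is a plain-Python mirror of them on the same values. It does not use TensorFlow at all; the lists only stand in for tensors, and the variable names are chosen for illustration.

```python
# Plain-Python mirror of the predicate -> tf.where -> tf.gather steps.
# No TensorFlow is used here; lists stand in for tensors (an illustration only).
batch = [5, 1, 9, 4, 7, 0]  # same values as in the example

# Role of tf.less: one boolean per element.
mask = [x < 5 for x in batch]

# Role of tf.where plus the reshape: positions of the True entries.
indices = [i for i, keep in enumerate(mask) if keep]

# Role of tf.gather: pick the elements at those positions.
found = [batch[i] for i in indices]

print(indices)  # [1, 3, 5]
print(found)    # [1, 4, 0]
```

    Three values pass the predicate, which is why the example can get away with a hardcoded dequeue_many(3).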

    A lot of things are hardcoded in this example that you would need to generalize in practice (for instance, the dequeue_many(3) size matches the three values that happen to pass the filter here), but hopefully it shows the structure of what you're trying to build: a filtering pipeline. In practice, you'd want QueueRunners in there to keep things churning automatically. tf.train.batch is a convenient way to get that; see Threading and Queues for more detail.
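
    As a rough analogy for what a queue runner contributes, the sketch below uses only Python's standard library: a background thread keeps moving matching items from one queue to another. It is not TensorFlow code; filter_worker and the None stop marker are hypothetical names and conventions made up for this illustration.

```python
import queue
import threading

def filter_worker(source, sink, predicate):
    """Move items that satisfy `predicate` from `source` to `sink`.

    Hypothetical helper: a None item is treated as a shutdown marker.
    """
    while True:
        item = source.get()
        if item is None:
            sink.put(None)  # pass the shutdown marker downstream
            break
        if predicate(item):
            sink.put(item)

raw = queue.Queue()       # stands in for the first queue
filtered = queue.Queue()  # stands in for the second queue

# The background thread plays the role a queue runner would play.
worker = threading.Thread(
    target=filter_worker, args=(raw, filtered, lambda x: x < 5), daemon=True
)
worker.start()

for value in [5, 1, 9, 4, 7, 0]:
    raw.put(value)
raw.put(None)  # signal that no more items are coming
worker.join()

# Drain the filtered queue until the shutdown marker.
results = []
while True:
    item = filtered.get()
    if item is None:
        break
    results.append(item)

print(results)  # [1, 4, 0]
```

    The consumer never has to trigger the filtering step itself, which is the property the hand-run sess.run(enqueue2) call in the example lacks.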
