How to determine placeholder dependency in TensorFlow

感动是毒 2021-01-03 10:48

Given a few symbolic variables to fetch, I need to know which placeholders they depend on.

In Theano, we have:

import theano as th
import theano.tens         


        
1 answer
  • 2021-01-03 11:28

    There isn't a built-in function (that I know of), but it's easy to write one:

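    # Set up a graph. tf.placeholder is TF1 graph-mode API; on TF 2.x it is
    # available through tf.compat.v1 once eager execution is disabled.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    placeholder0 = tf.placeholder(tf.float32, [])
    placeholder1 = tf.placeholder(tf.float32, [])
    constant0 = tf.constant(2.0)
    sum0 = tf.add(placeholder0, constant0)
    sum1 = tf.add(placeholder1, sum0)

    # Function to get *all* (transitive) input tensors of a tensor.
    # `dependencies` doubles as a visited set, so tensors reachable through
    # several paths in the graph are only expanded once.
    def get_dependencies(tensor, dependencies=None):
        if dependencies is None:
            dependencies = set()
        for input_tensor in tensor.op.inputs:
            if input_tensor not in dependencies:
                dependencies.add(input_tensor)
                get_dependencies(input_tensor, dependencies)
        return dependencies

    print(get_dependencies(sum0))
    print(get_dependencies(sum1))
    # Filter on op type to get the placeholders.
    print([tensor for tensor in get_dependencies(sum0) if tensor.op.type == 'Placeholder'])
    print([tensor for tensor in get_dependencies(sum1) if tensor.op.type == 'Placeholder'])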
    

    You could, of course, move the placeholder filtering into the function itself.
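    A minimal sketch of folding the filter into the traversal is below. It assumes TF1-style graph-mode tensors, as in the code above, where each tensor exposes `tensor.op.inputs` and `tensor.op.type`; the name `get_placeholder_dependencies` is just a hypothetical helper name. It walks the graph with an explicit stack instead of recursion, which also avoids Python's recursion limit on very deep graphs.

```python
# Sketch: collect the Placeholder tensors that a tensor transitively depends on.
# Assumes TF1-style graph-mode tensors: each tensor has `.op`, and each op has
# `.inputs` (its input tensors) and `.type` (e.g. 'Placeholder').
# `get_placeholder_dependencies` is a hypothetical helper name.
def get_placeholder_dependencies(tensor):
    placeholders = set()
    visited = set()
    stack = list(tensor.op.inputs)  # start from the direct inputs
    while stack:
        current = stack.pop()
        if current in visited:
            continue  # already expanded via another path
        visited.add(current)
        if current.op.type == 'Placeholder':
            placeholders.add(current)
        stack.extend(current.op.inputs)  # follows data inputs only
    return placeholders
```

    For the graph above, `get_placeholder_dependencies(sum1)` should return a set containing `placeholder0` and `placeholder1`, and `get_placeholder_dependencies(sum0)` a set containing only `placeholder0`. Like the original function, it follows only data inputs (`op.inputs`), not control dependencies (`op.control_inputs`).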
