List of tensor names in a graph in TensorFlow

死守一世寂寞 2020-11-28 21:04

The graph object in TensorFlow has a method called "get_tensor_by_name(name)". Is there any way to get a list of valid tensor names?

If not, does anyone know the va

6 answers
  • 2020-11-28 21:12

    To see the operations in the graph (there are many, so to keep it short I show only one entry here):

    import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
    sess = tf.Session()
    ops = sess.graph.get_operations()
    # output tensors of the op at index 1
    [m.values() for m in ops][1]
    
    out:
    (<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)
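    The snippet above uses the TensorFlow 1.x tf.Session API. As a sketch assuming TensorFlow 2.x (where tf.Session is gone and ops run eagerly by default), you can build ops inside an explicit tf.Graph and list each op's output tensors. The constants x and w and the op name y are made up for illustration.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x: build the ops inside an explicit graph so they are
# recorded as graph operations instead of being executed eagerly.
graph = tf.Graph()
with graph.as_default():
    x = tf.constant([[1.0, 2.0]], name="x")    # illustrative input
    w = tf.constant([[3.0], [4.0]], name="w")  # illustrative weights
    y = tf.matmul(x, w, name="y")

tensor_names = []
for op in graph.get_operations():
    # op.outputs holds the same tensors that op.values() returns
    tensor_names.extend(t.name for t in op.outputs)

print(tensor_names)
```

    Any name in tensor_names can then be passed to graph.get_tensor_by_name.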
    
  • 2020-11-28 21:16

    The answers above are correct. I came across some simple, easy-to-understand code for this task, so I'm sharing it here:

    import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.GraphDef)
    
    def printTensors(pb_file):
    
        # read pb into graph_def
        with tf.gfile.GFile(pb_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    
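        # import graph_def; every name gets an "import/" prefix by default
        # (pass name="" to tf.import_graph_def to keep the original names)
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def)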
    
        # print operations
        for op in graph.get_operations():
            print(op.name)
    
    
    printTensors("path-to-my-pbfile.pb")
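    A variant that skips the import step, sketched under the assumption of TensorFlow 2.x (tf.io.gfile and tf.compat.v1.GraphDef): every NodeDef in graph_def.node is one operation, so the op names can be read straight from the parsed protobuf, without the "import/" prefix that tf.import_graph_def adds by default. The helper list_node_names, the toy graph and the temporary model.pb path are hypothetical stand-ins for your own .pb file.

```python
import os
import tempfile

import tensorflow as tf


def list_node_names(pb_file):
    """Return the op (node) names stored in a serialized GraphDef .pb file."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    # One NodeDef per operation; node.op would give the op type (e.g. "Const").
    return [node.name for node in graph_def.node]


# Hypothetical toy graph written to a temporary .pb file for demonstration.
graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    tf.add(a, b, name="total")

pb_path = os.path.join(tempfile.mkdtemp(), "model.pb")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(graph.as_graph_def().SerializeToString())

print(list_node_names(pb_path))

# For comparison: importing the same GraphDef with the default name prefix.
imported = tf.Graph()
with imported.as_default():
    tf.compat.v1.import_graph_def(graph.as_graph_def())
print([op.name for op in imported.get_operations()])  # every name starts with "import/"
```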
    
  • 2020-11-28 21:18

    As a nested list comprehension:

    tensor_names = [t.name for op in tf.get_default_graph().get_operations() for t in op.values()]
    

    Function to get the names of the tensors in a graph (defaults to the default graph at call time):

    def get_names(graph=None):
        if graph is None:
            graph = tf.get_default_graph()
        return [t.name for op in graph.get_operations() for t in op.values()]
    

    Function to get the tensors in a graph (defaults to the default graph at call time):
    
    def get_tensors(graph=None):
        if graph is None:
            graph = tf.get_default_graph()
        return [t for op in graph.get_operations() for t in op.values()]
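    The point of collecting these names is that each one is a valid argument to get_tensor_by_name. A minimal round-trip sketch, assuming TensorFlow 2.x with an explicit tf.Graph; the placeholder x and the op name doubled are illustrative.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=(None, 3), name="x")
    doubled = tf.add(x, x, name="doubled")

# Same nested comprehension as in this answer, using op.outputs.
names = [t.name for op in graph.get_operations() for t in op.outputs]

for name in names:
    # get_tensor_by_name expects the "<op_name>:<output_index>" form.
    tensor = graph.get_tensor_by_name(name)
    print(name, tensor.shape, tensor.dtype)
```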
    
  • 2020-11-28 21:23

    The paper does not accurately reflect the model. If you download the source from arXiv, it includes an accurate model description in model.txt, and the names in there correspond closely to the names in the released model.

    To answer your first question, sess.graph.get_operations() gives you a list of operations. For an op, op.name gives you its name and op.values() gives you a tuple of the tensors it produces. A tensor's name is the op name followed by ":" and the index of that output; in the inception-v3 model all tensor names are the op name with ":0" appended, so pool_3:0 is the tensor produced by the final pooling op.
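    To make the naming rule concrete, here is a small hypothetical helper, split_tensor_name, in plain Python (no TensorFlow needed). Ops with several outputs, such as tf.split, produce names like split:0, split:1, and so on. Like get_tensor_by_name, the helper rejects a bare op name that lacks the ":<index>" suffix.

```python
def split_tensor_name(tensor_name):
    """Split a TensorFlow-style tensor name 'op_name:index' into (op_name, index).

    Hypothetical helper: it only parses the string and never touches a graph.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(f"{tensor_name!r} is not of the form '<op_name>:<index>'")
    return op_name, int(index)


print(split_tensor_name("pool_3:0"))         # ('pool_3', 0)
print(split_tensor_name("conv1/weights:0"))  # ('conv1/weights', 0)
print(split_tensor_name("split:1"))          # ('split', 1)
```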

  • 2020-11-28 21:25

    You do not even have to create a session to see the names of all the operations in the graph. Just grab the default graph with tf.get_default_graph() and extract all its operations with .get_operations(). Each operation has many fields; the one you need is name.

    Here is the code:

    import tensorflow as tf  # TensorFlow 1.x (graph mode by default)
    a = tf.Variable(5)
    b = tf.Variable(6)
    c = tf.Variable(7)
    d = (a + b) * c
    
    for i in tf.get_default_graph().get_operations():
        print(i.name)
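    This relies on TensorFlow 1.x graph mode. In TensorFlow 2.x, tf.get_default_graph() only exists as tf.compat.v1.get_default_graph(), and eagerly executed ops are not recorded in any graph. A sketch assuming TF 2.x: wrap the computation in a tf.function (compute is a made-up name) and list the operations of its traced graph. The exact op names TF generates (for example for variable reads) may differ between versions.

```python
import tensorflow as tf

a = tf.Variable(5)
b = tf.Variable(6)
c = tf.Variable(7)


@tf.function
def compute():
    return (a + b) * c


# Tracing the function builds a tf.Graph that can be inspected like a TF 1.x graph.
concrete = compute.get_concrete_function()
op_names = [op.name for op in concrete.graph.get_operations()]
for name in op_names:
    print(name)
```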
    
  • 2020-11-28 21:28

    saved_model_cli is an alternative command-line tool that ships with TF and can be useful if you're dealing with the SavedModel format. From the docs:

    !saved_model_cli show --dir /tmp/mobilenet/1 --tag_set serve --all
    

    The output looks something like this:

    MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
    
    signature_def['__saved_model_init_op']:
      The given SavedModel SignatureDef contains the following input(s):
      The given SavedModel SignatureDef contains the following output(s):
        outputs['__saved_model_init_op'] tensor_info:
            dtype: DT_INVALID
            shape: unknown_rank
            name: NoOp
      Method name is: 
    
    signature_def['serving_default']:
      The given SavedModel SignatureDef contains the following input(s):
        inputs['dense_input'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, 1280)
            name: serving_default_dense_input:0
      The given SavedModel SignatureDef contains the following output(s):
        outputs['dense_1'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, 1)
            name: StatefulPartitionedCall:0
      Method name is: tensorflow/serving/predict
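    If you would rather stay in Python, a loaded SavedModel exposes the same signature information. A sketch assuming TensorFlow 2.x: it exports a toy module (Doubler, x and y are made-up names) to a temporary directory, then reads the input and output specs of its serving_default signature, which is the information saved_model_cli prints.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Toy module standing in for a real model."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="x")])
    def __call__(self, x):
        return {"y": x * 2.0}


export_dir = tempfile.mkdtemp()
module = Doubler()
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.__call__})

loaded = tf.saved_model.load(export_dir)
serving_fn = loaded.signatures["serving_default"]

# Signature inputs are keyword arguments, so the specs live in the kwargs dict
# of the (args, kwargs) pair.
_, input_specs = serving_fn.structured_input_signature
print(input_specs)
print(serving_fn.structured_outputs)
```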
    