In TensorFlow, get the names of all the Tensors in a graph

独厮守ぢ 2020-11-27 10:29

I am creating neural nets with TensorFlow and skflow; for some reason I want to get the values of some inner tensors for a given input, so I am usi

10 answers
  • 2020-11-27 10:31

    The accepted answer only gives you a list of strings with the names. I prefer a different approach, which gives you (almost) direct access to the tensors:

    graph = tf.get_default_graph()
    list_of_tuples = [op.values() for op in graph.get_operations()]
    

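    list_of_tuples now contains every tensor, grouped into one tuple per operation. You could also adapt it to get the first output tensor of each operation directly:

    graph = tf.get_default_graph()
    # Take the first output of each op; ops without outputs (e.g. NoOp) are skipped
    # so that indexing with [0] cannot raise an IndexError.
    first_outputs = [op.values()[0] for op in graph.get_operations() if op.values()]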
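
As a rough illustration of the approach above, the sketch below builds a tiny throwaway graph and flattens the outputs of every operation into one list. The op names `x` and `y` are made up for the example, and `tf.compat.v1` is used on the assumption that the code may run under TensorFlow 2.x; `op.outputs` holds the same tensors that `op.values()` returns as a tuple.

```python
import tensorflow as tf

# Build a small example graph (the ops "x" and "y" are hypothetical names).
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="x")
    y = tf.identity(x * 2.0, name="y")

# Flatten the outputs of every operation into one list of tensors.
# Operations without outputs simply contribute nothing.
tensors = [t for op in graph.get_operations() for t in op.outputs]
tensor_names = [t.name for t in tensors]
print(tensor_names)
```

Unlike taking only `op.values()[0]`, this keeps every output of operations that produce more than one tensor.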
    
  • 2020-11-27 10:36

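    tf.all_variables() can get you the information you want (later TensorFlow 1.x releases deprecate it in favor of tf.global_variables()). Note that it returns the graph's variables, not every tensor.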

    Also, this commit made today in TensorFlow Learn adds a get_variable_names function to the estimator, which you can use to retrieve all variable names easily.

  • 2020-11-27 10:36

    I think this will do too:

    print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
    

    But compared with Salvado and Yaroslav's answers, I don't know which one is better.

  • 2020-11-27 10:38

    Since the OP asked for the list of the tensors instead of the list of operations/nodes, the code should be slightly different:

    graph = tf.get_default_graph()    
    tensors_per_node = [node.values() for node in graph.get_operations()]
    tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]
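
The names collected this way follow the pattern `<operation name>:<output index>`, which is why a tensor name differs from the name of the operation that produces it. A minimal sketch of taking such a name apart is shown here; `split_tensor_name` is a hypothetical helper written for this example, not part of TensorFlow.

```python
def split_tensor_name(tensor_name):
    """Split a tensor name such as 'dense/BiasAdd:0' into (op_name, output_index)."""
    # Split at the last colon; everything before it is the operation name.
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError("not a tensor name of the form 'op_name:index': %r" % tensor_name)
    return op_name, int(index)


print(split_tensor_name("dense/BiasAdd:0"))  # prints ('dense/BiasAdd', 0)
```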
    
  • 2020-11-27 10:42

    This worked for me:

    for n in tf.get_default_graph().as_graph_def().node:
        print('\n',n)
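
Printing whole nodes can get verbose for large graphs. As a more compact variant, the sketch below prints only the `name`, `op` and `input` fields of each `NodeDef`. The three-op graph and its names (`a`, `b`, `total`) are invented for the example, and `tf.compat.v1`-style graph building is assumed to work in the installed version.

```python
import tensorflow as tf

# Small example graph with hypothetical op names.
graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    tf.add(a, b, name="total")

# Each NodeDef carries the node name, the op type, and the names of its inputs.
summary = [(node.name, node.op, list(node.input))
           for node in graph.as_graph_def().node]
for name, op_type, inputs in summary:
    print(name, op_type, inputs)
```

Note that these are node (operation) names; the corresponding output tensors carry an extra `:<index>` suffix.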
    
  • 2020-11-27 10:43

    The following solution works for me in TensorFlow 2.3:

    import tensorflow as tf

    def load_pb(path_to_pb):
        with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')
            return graph
    
    tf_graph = load_pb(MODEL_FILE)
    sess = tf.compat.v1.Session(graph=tf_graph)
    
    # Print each operation's output tensors; their repr includes the tensor name
    for op in tf_graph.get_operations():
        print(op.values())
    

    where MODEL_FILE is the path to your frozen graph.

    Taken from here.
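
Once a tensor name is known, the value of that inner tensor for a given input can be fetched by name, which is what the question is after. The following is only a sketch on a made-up graph (the names `x`, `hidden` and `out` are assumptions for the example); with a loaded frozen graph you would pass that graph and its own tensor names instead.

```python
import numpy as np
import tensorflow as tf

# Example graph with a hypothetical inner tensor called "hidden".
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    hidden = tf.multiply(x, 3.0, name="hidden")
    tf.reduce_sum(hidden, axis=1, name="out")

# Look the inner tensor up by its name: "<op name>:<output index>".
inner = graph.get_tensor_by_name("hidden:0")

# Evaluate it for one input row.
with tf.compat.v1.Session(graph=graph) as sess:
    value = sess.run(inner, feed_dict={x: np.array([[1.0, 2.0]], dtype=np.float32)})
print(value)
```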
