tf.GraphKeys.TRAINABLE_VARIABLES on output_graph.pb resulting in empty list

故里飘歌 2021-02-10 06:14

I'm trying to extract all the weights/biases from a saved model output_graph.pb.

I read the model:

def create_graph(modelFullPath):
    \"\         


        
1 answer
  • 2021-02-10 06:56

    The tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection (for that, you would need a MetaGraphDef), which is why the list comes back empty. However, if output_graph.pb contains a "frozen" GraphDef, then all of the weights will be stored in tf.constant() nodes in the graph. To extract them, you can do something like the following:

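    # create_graph (from the question) is expected to import the GraphDef
    # into the default graph.
    create_graph(GRAPH_DIR)

    constant_values = {}

    # TF 1.x API; under TF 2.x use tf.compat.v1.Session instead.
    with tf.Session() as sess:
      # In a frozen graph, the weights are stored in ops of type "Const".
      constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
      for constant_op in constant_ops:
        # Evaluate the op's single output to get its value as a NumPy array.
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])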
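For reference, here is a self-contained variant of the same idea, as a rough sketch only. It assumes a TensorFlow release that provides the tf.compat.v1 endpoints (late 1.x or 2.x). The helper names load_frozen_graph and extract_constants are hypothetical, and load_frozen_graph is only a guess at what the truncated create_graph in the question does.

```python
import tensorflow as tf


def load_frozen_graph(pb_path):
    """Parse a serialized GraphDef file and import it into a fresh tf.Graph.

    Hypothetical helper: an assumption about what create_graph does.
    """
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original op names instead of prefixing "import/".
        tf.import_graph_def(graph_def, name="")
    return graph


def extract_constants(graph):
    """Return {op name: NumPy value} for every Const op in the graph."""
    const_ops = [op for op in graph.get_operations() if op.type == "Const"]
    with tf.compat.v1.Session(graph=graph) as sess:
        return {op.name: sess.run(op.outputs[0]) for op in const_ops}


# Usage (the path is whatever file you saved):
# constant_values = extract_constants(load_frozen_graph("output_graph.pb"))
```

Importing into a dedicated tf.Graph rather than the default graph keeps the constants of this model separate from anything else defined in the process.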
    

    Note that constant_values will probably contain more values than just the weights, so you may need to filter further by op.name or some other criterion.
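One possible way to do that filtering is sketched below. It assumes the weights and biases can be recognised by the end of their op names; the names in the example dictionary are made up for illustration, and filter_by_name is a hypothetical helper, so check the actual names in your own graph first.

```python
import re


def filter_by_name(constant_values, pattern):
    """Keep only the entries whose op name matches the regular expression."""
    matcher = re.compile(pattern)
    return {
        name: value
        for name, value in constant_values.items()
        if matcher.search(name)
    }


# Hypothetical op names, for illustration only; plain lists stand in for
# the NumPy arrays that sess.run() would return.
constant_values = {
    "conv1/weights": [[0.1, 0.2]],
    "conv1/biases": [0.0],
    "Reshape/shape": [-1, 2],
}

# Keep ops whose name ends in "/weights" or "/biases".
weights_and_biases = filter_by_name(constant_values, r"/(weights|biases)$")
print(sorted(weights_and_biases))
```

Printing sorted(constant_values) before choosing a pattern is an easy way to see which naming scheme the graph actually uses.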
