I'm trying to extract all the weights/biases from a saved model, output_graph.pb. I read the model with:
```python
def create_graph(modelFullPath):
    """..."""
```
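The body of create_graph is cut off above. Below is a minimal sketch of what such a loader typically does. It assumes a TensorFlow version that exposes the 1.x-style graph API through tf.compat.v1 (so it also runs under TensorFlow 2.x). It reads the serialized GraphDef from disk, parses it, and imports it into the default graph. The name load_frozen_graph is hypothetical and is not the asker's code.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def load_frozen_graph(model_full_path):
    """Hypothetical stand-in for create_graph(): parse a serialized GraphDef
    (.pb file) and import it into the current default graph."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(model_full_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # name="" keeps the original node names instead of prefixing them with "import/".
    tf1.import_graph_def(graph_def, name="")
    return tf1.get_default_graph()
```

Passing name="" matters for the extraction code that follows, because op names then match the node names stored in the file.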
The tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection (for that, you would need a MetaGraphDef). However, if output_graph.pb contains a "frozen" GraphDef (one in which the variables have been converted to constants), then all of the weights will be stored in Const nodes, the op type that tf.constant() creates. To extract them, you can do something like the following:
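```python
import tensorflow as tf

create_graph(GRAPH_DIR)  # import the frozen GraphDef into the default graph

constant_values = {}
with tf.Session() as sess:
    # In a frozen graph, every weight is stored in an op of type "Const".
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        # Evaluate the op's single output tensor to get its value as a NumPy array.
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
```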
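A session-free variant of the same idea, sketched here as an alternative rather than as part of the answer above: a Const node keeps its value as a TensorProto in its "value" attr, so the values can be read straight from a parsed GraphDef with tf.make_ndarray, without importing the graph or running anything. This assumes a TensorFlow version that provides tf.make_ndarray; the function name is hypothetical.

```python
import tensorflow as tf


def const_values_from_graph_def(graph_def):
    """Map each Const node's name to its value as a NumPy array (no Session)."""
    values = {}
    for node in graph_def.node:
        if node.op == "Const":
            # The constant's value is serialized as a TensorProto in the "value" attr.
            values[node.name] = tf.make_ndarray(node.attr["value"].tensor)
    return values
```

Given a GraphDef parsed from output_graph.pb (tf.compat.v1.GraphDef() followed by ParseFromString on the file's bytes), this yields essentially the same name-to-array mapping as the Session loop, keyed by node name.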
Note that constant_values will probably contain more values than just the weights (for example, shape arguments and other constants the graph uses), so you may need to filter further by op.name or some other criterion.
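One way to do that filtering is sketched below. It assumes that parameter tensors can be recognized by a name suffix, have at least one dimension, and use a floating-point dtype. The suffixes "weights" and "biases" are hypothetical defaults, so inspect sorted(constant_values) for your own graph and adjust them.

```python
import numpy as np


def select_parameters(constant_values, name_suffixes=("weights", "biases"), min_ndim=1):
    """Keep only entries that look like layer weights/biases (heuristic filter)."""
    selected = {}
    for name, value in constant_values.items():
        arr = np.asarray(value)
        if not name.endswith(tuple(name_suffixes)):
            continue  # name does not end with any expected parameter suffix
        if arr.ndim < min_ndim or not np.issubdtype(arr.dtype, np.floating):
            continue  # skip scalars and integer constants such as shape arguments
        selected[name] = arr
    return selected
```

Filtering on dtype and rank as well as on name guards against integer shape constants or scalar hyperparameters that happen to share a suffix.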