Question
I have successfully exported a retrained InceptionV3 NN as a TensorFlow meta graph. I have read this protobuf back into Python successfully, but I am struggling to find a way to export each layer's weight and bias values, which I am assuming are stored within the meta graph protobuf, so that I can recreate the NN outside of TensorFlow.
My workflow is as follows:
1. Retrain the final layer for new categories.
2. Export the meta graph: tf.train.export_meta_graph(filename='model.meta')
3. Build the Python module meta_graph_pb2.py from meta_graph.proto using protoc.
4. Load the protobuf:
    import meta_graph_pb2

    # model.meta holds a serialized MetaGraphDef, so parse it as that message type.
    saved = meta_graph_pb2.MetaGraphDef()
    with open('model.meta', 'rb') as f:
        saved.ParseFromString(f.read())
From here I can view most aspects of the graph, like node names and such, but I think my inexperience is making it difficult to track down the correct way to access the weight and bias values for each relevant layer.
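To see what a parsed .meta file does and does not hold, here is a minimal sketch that parses it as a MetaGraphDef and lists the nodes that declare variables. It assumes a TensorFlow install with the tf.compat.v1 API and imports meta_graph_pb2 from tensorflow.core.protobuf (the copy bundled with TensorFlow) instead of a protoc-generated module. export_toy_meta_graph, the variable names final_weights/final_biases, and the VARIABLE_OPS set are hypothetical stand-ins chosen for illustration. The variable nodes only record names, dtypes and shapes, not trained values.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x

# Assumed set of op types that declare variables in a GraphDef:
# legacy ref variables use Variable/VariableV2, resource variables VarHandleOp.
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}


def export_toy_meta_graph(meta_path):
    """Hypothetical stand-in for the retrained model: one weight, one bias."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, [None, 4], name="x")
        w = tf1.get_variable("final_weights", [4, 2])
        b = tf1.get_variable("final_biases", [2],
                             initializer=tf1.zeros_initializer())
        tf.nn.bias_add(tf.matmul(x, w), b, name="logits")
        tf1.train.export_meta_graph(filename=meta_path, graph=graph)


def list_variable_nodes(meta_path):
    """Parse a .meta file as a MetaGraphDef and return (name, op) per variable node."""
    meta = meta_graph_pb2.MetaGraphDef()
    with open(meta_path, "rb") as f:
        meta.ParseFromString(f.read())
    return [(node.name, node.op) for node in meta.graph_def.node
            if node.op in VARIABLE_OPS]


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "model.meta")
    export_toy_meta_graph(path)
    for name, op in list_variable_nodes(path):
        print(name, op)  # names and op types only; no trained values
```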
Answer 1:
The MetaGraphDef proto doesn't actually contain the values of the weights and biases. Instead, it provides a way to associate a GraphDef with the weights stored in one or more checkpoint files, written by a tf.train.Saver. The MetaGraphDef tutorial has more details, but the approximate structure is as follows:
1. In your training program, write out a checkpoint using a tf.train.Saver. This will also write a MetaGraphDef to a .meta file in the same directory.

    saver = tf.train.Saver(...)
    # ...
    saver.save(sess, "model")

   You should find files called model.meta and model-NNNN (for some integer NNNN) in your checkpoint directory.

2. In another program, you can import the MetaGraphDef you just created, and restore from a checkpoint.

    saver = tf.train.import_meta_graph("model.meta")
    sess = tf.Session()
    saver.restore(sess, "model-NNNN")  # Or whatever checkpoint filename was written.
If you want to get the value of each variable, you can (for example) find the variable in the tf.all_variables() collection and pass it to sess.run() to get its value. For example, to print the values of all variables, you can do the following:

    for var in tf.all_variables():
        print(var.name, sess.run(var))
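Since the goal is to rebuild the network outside TensorFlow, one option (an assumption about your target environment, not something TensorFlow requires) is to write the values from this loop into a NumPy .npz archive that can later be loaded with NumPy alone. The sketch below assumes a dict mapping variable names such as "final_weights:0" to arrays, for example {var.name: sess.run(var) for var in tf.all_variables()}; save_weights and load_weights are hypothetical helper names. It stores the names in a separate array so that "/" and ":" never have to appear in archive member names.

```python
import numpy as np


def save_weights(values, path):
    """Store a {variable_name: array} dict in a .npz archive."""
    names = sorted(values)
    # Positional arrays are stored as arr_0, arr_1, ...; the "names" array
    # records which TensorFlow variable each positional array came from.
    np.savez(path, *[np.asarray(values[n]) for n in names], names=np.array(names))


def load_weights(path):
    """Inverse of save_weights; needs only NumPy, no TensorFlow import."""
    with np.load(path) as data:
        names = [str(n) for n in data["names"]]
        return {name: data["arr_%d" % i] for i, name in enumerate(names)}
```

Assuming the retrained final layer is a plain fully connected layer, its forward pass can then be written directly in NumPy as logits = x @ W + b followed by a softmax.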
You could also filter tf.all_variables() to find the particular weights and biases that you're trying to extract from the model.
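As a minimal sketch of that filtering, tf.compat.v1.global_variables() accepts a scope argument that is matched against each variable name with re.match, so it keeps only names that start with that prefix. The scope "final_training_ops" and the nested names below are an assumption modeled on how TensorFlow's retrain.py example names its new layer; compare them with the names printed by the loop above for your own graph. build_toy_graph and extract_final_layer are hypothetical helpers, and the toy graph is initialized rather than restored from a checkpoint.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TF 2.x


def build_toy_graph():
    """Hypothetical graph: one 'pretrained' kernel plus a retrained final layer."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.variable_scope("conv"):
            tf1.get_variable("kernel", [3, 3, 1, 8])
        with tf1.variable_scope("final_training_ops"):
            with tf1.variable_scope("weights"):
                tf1.get_variable("final_weights", initializer=tf.ones([8, 2]))
            with tf1.variable_scope("biases"):
                tf1.get_variable("final_biases", initializer=tf.zeros([2]))
    return graph


def extract_final_layer(graph, scope="final_training_ops"):
    """Return {name: value} for the variables whose names start with `scope`."""
    with graph.as_default():
        wanted = tf1.global_variables(scope=scope)
        with tf1.Session() as sess:
            # With a real model you would call saver.restore(sess, ...) here.
            sess.run(tf1.global_variables_initializer())
            return {var.name: value
                    for var, value in zip(wanted, sess.run(wanted))}


if __name__ == "__main__":
    for name, value in extract_final_layer(build_toy_graph()).items():
        print(name, value.shape)
```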
Source: https://stackoverflow.com/questions/39133285/weights-and-bias-from-trained-meta-graph