I am trying to do a very simple save of a TensorFlow graph as a .pb file, but I get this error when parsing it back:
Traceback (most recent call last):
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
Here fileContent should be a "frozen graph". TensorFlow provides an API for this as well; refer to the TensorFlow freeze_graph API.
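One quick way to check whether fileContent is a GraphDef at all, before calling ParseFromString, is to look at its first protobuf tag byte. In TensorFlow's .proto definitions, field 1 of SavedModel is the integer saved_model_schema_version, while field 1 of GraphDef is the repeated node message. A serialized SavedModel therefore normally begins with 0x08 (field 1, varint), and a non-empty GraphDef begins with 0x0a (field 1, length-delimited). This is a heuristic sketch, not a TensorFlow API. The name guess_pb_kind is hypothetical, and the check assumes the serializer wrote fields in field-number order, which is typical but not guaranteed.

```python
def guess_pb_kind(data):
    """Heuristically guess what a serialized TensorFlow .pb payload holds.

    Only the first protobuf tag byte is inspected (hypothetical helper):
      0x08 -> field 1, wire type 0 (varint): SavedModel.saved_model_schema_version
      0x0a -> field 1, wire type 2 (length-delimited): GraphDef.node
    Anything else is reported as "unknown".
    """
    if not data:
        return "unknown"
    first = data[0]  # indexing bytes yields an int in Python 3
    field_number, wire_type = first >> 3, first & 0x07
    if field_number == 1 and wire_type == 0:
        return "saved_model"
    if field_number == 1 and wire_type == 2:
        return "graph_def"
    return "unknown"
```

For example, if guess_pb_kind(fileContent) returns "saved_model", the bytes came from a SavedModel export and should be loaded with tf.saved_model.loader.load rather than GraphDef.ParseFromString.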
Another way to create a frozen graph is:
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # Rebuild the graph from the .meta file and restore the trained weights
    saver = tf.train.import_meta_graph(<.meta file>)
    saver.restore(sess, <checkpoint>)
    # Replace every variable with a constant holding its current value
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        [list of output node names]
    )

# Save "output_graph_def" to a file; this file is the frozen graph.
with tf.gfile.GFile('frozen_graph.pb', "wb") as f:
    f.write(output_graph_def.SerializeToString())
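Use frozen_graph.pb as follows (ParseFromString expects the file's bytes, not its name):
with tf.gfile.GFile('frozen_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())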
So first use a Saver object to generate the .meta and checkpoint files. Once that is done, create the frozen graph.
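Before freezing, it can help to confirm that the Saver step actually produced the files the freezing code needs. This is a minimal pure-Python sketch. It assumes the TF1 V2 checkpoint layout, in which saver.save(sess, prefix) writes prefix.meta and prefix.index plus one or more prefix.data-* shards. The helper name checkpoint_files_present is hypothetical.

```python
import glob
import os


def checkpoint_files_present(prefix):
    """Return True if the files a TF1 Saver writes for `prefix` all exist.

    Hypothetical helper; assumes the V2 checkpoint layout:
      <prefix>.meta    -> the MetaGraphDef used by import_meta_graph
      <prefix>.index   -> the checkpoint index
      <prefix>.data-*  -> at least one variable-data shard
    """
    has_meta = os.path.isfile(prefix + ".meta")
    has_index = os.path.isfile(prefix + ".index")
    # Escape the prefix so characters like [ or * in the path are literal
    has_data = bool(glob.glob(glob.escape(prefix) + ".data-*"))
    return has_meta and has_index and has_data
```

If this returns False for the prefix you pass to saver.restore, the freezing step has nothing to restore from, so rerun the save first.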
The problem here is that you are trying to parse a SavedModel protocol buffer as if it were a GraphDef. Although a SavedModel contains a GraphDef, the two have different binary formats. The following code, using tf.saved_model.loader.load(), should work:
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], "models/TEST-3")
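If it is not obvious which kind of artifact a path points to, a small check before loading can pick between the two routes discussed in this thread. This is a minimal sketch. It assumes the standard SavedModel layout: a directory containing saved_model.pb or saved_model.pbtxt, usually alongside a variables/ subdirectory. It also treats any other .pb file as a frozen GraphDef. The function name choose_loader and its return labels are hypothetical.

```python
import os


def choose_loader(path):
    """Return which loading route fits `path` (hypothetical labels).

    "saved_model" -> a SavedModel export; load the directory with
                     tf.saved_model.loader.load(sess, tags, export_dir)
    "graph_def"   -> a standalone .pb file, assumed to be a frozen GraphDef;
                     read its bytes and call GraphDef.ParseFromString
    "unknown"     -> anything else
    """
    if os.path.isdir(path):
        for name in ("saved_model.pb", "saved_model.pbtxt"):
            if os.path.isfile(os.path.join(path, name)):
                return "saved_model"
        return "unknown"
    if os.path.isfile(path) and path.endswith(".pb"):
        # A saved_model.pb taken from an export directory is still a
        # SavedModel (load its parent directory), not a GraphDef.
        if os.path.basename(path) == "saved_model.pb":
            return "saved_model"
        return "graph_def"
    return "unknown"
```

With the paths from this thread, choose_loader("models/TEST-3") would be expected to return "saved_model" for an export directory, and choose_loader("frozen_graph.pb") would return "graph_def".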