TensorFlow 0.12 Model Files

Submitted by 风格不统一 on 2019-12-01 20:24:24

Question


I train a model and save it using:

saver = tf.train.Saver()
saver.save(session, './my_model_name')

Besides the checkpoint file, which simply contains pointers to the most recent checkpoints of the model, this creates the following 3 files in the current path:

  1. my_model_name.meta
  2. my_model_name.index
  3. my_model_name.data-00000-of-00001

I wonder what each of these files contains.

I'd like to load this model in C++ and run inference. The label_image example loads the model from a single .pb file using ReadBinaryProto(). I wonder how I can load it from these 3 files instead. What is the C++ equivalent of the following?

new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')

Answer 1:


I'm currently struggling with this myself, and I've found that it's not very straightforward at the moment. The two most commonly cited tutorials on the subject are: https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.goxwm1e5j and https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.g1gak956i

The equivalent of

new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')

is just

Status load_graph_status = LoadGraph(graph_path, &session);

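assuming you've "frozen the graph" (used a script which combines the graph file with the checkpoint values). Note that LoadGraph is the helper function defined in the label_image example, not a TensorFlow API call. Also, see the discussion here: Tensorflow Different ways to Export and Run graph in C++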




Answer 2:


What your saver creates is called "Checkpoint V2" and was introduced in TF 0.12.
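
To make the role of each file a bit more concrete, here is a rough sketch that lists the files a given prefix expands to, with a one-line description of what (as far as I understand the format) each one holds. The function name checkpoint_v2_files and the description strings are made up for illustration, and num_shards=1 is an assumption that matches the single data file in the question.

```python
def checkpoint_v2_files(prefix, num_shards=1):
    """Map each file expected for a Checkpoint V2 prefix to a short description.

    Hypothetical helper for illustration only; it does not touch the disk.
    """
    files = {
        prefix + ".meta": "serialized MetaGraphDef (graph structure and saver metadata)",
        prefix + ".index": "index from tensor names to their location in the data shards",
    }
    for shard in range(num_shards):
        # Shard names use zero-padded five-digit counters, e.g. data-00000-of-00001.
        name = "%s.data-%05d-of-%05d" % (prefix, shard, num_shards)
        files[name] = "variable values (shard %d of %d)" % (shard + 1, num_shards)
    return files


if __name__ == "__main__":
    for name, description in sorted(checkpoint_v2_files("my_model_name").items()):
        print(name, "->", description)
```

The point is that none of these names is something you pass to the saver directly; only the prefix is.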

I got it working quite nicely (though the docs on the C++ part are horrible, so it took me a day to solve). Some people suggest converting all variables to constants or freezing the graph, but none of these is actually needed.

Python part (saving)

with tf.Session() as sess:
    tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

If you create the Saver with tf.trainable_variables(), you can save yourself some headaches and storage space. More complicated models may need all of their data saved (for example, non-trainable state such as batch-normalization moving averages); in that case, remove this argument to Saver, and just make sure you create the Saver after your graph is built. It is also wise to give all variables and layers unique names, otherwise you can run into various problems.

C++ part (inference)

Note that checkpointPath isn't a path to any of the existing files, just their common prefix. If you mistakenly pass the path to the .index file instead, TF won't tell you that it was wrong, but inference will fail later because of uninitialized variables.
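
Since that mistake fails silently, it can help to normalize the path before handing it over. Below is a minimal sketch in Python (the same idea ports to C++ easily); checkpoint_prefix is a hypothetical helper, not part of TensorFlow, and it assumes the usual .meta, .index and .data-NNNNN-of-NNNNN suffixes.

```python
import re

# Suffixes that a Checkpoint V2 save appends to the prefix (assumed naming).
_CHECKPOINT_SUFFIX = re.compile(r"\.(?:meta|index|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Strip a known checkpoint file suffix and return the common prefix.

    Paths that already are a bare prefix are returned unchanged.
    """
    return _CHECKPOINT_SUFFIX.sub("", path)


if __name__ == "__main__":
    print(checkpoint_prefix("models/my-model.index"))
```

One caveat of this design: a model whose prefix itself ends in one of these suffixes would be truncated, so treat it as a convenience check rather than a guarantee.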

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;

...
// set up your input paths
const string pathToGraph = "models/my-model.meta";
const string checkpointPath = "models/my-model";
...

auto session = NewSession(SessionOptions());
if (session == nullptr) {
    throw runtime_error("Could not create Tensorflow session.");
}

Status status;

// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
    throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}

// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
    throw runtime_error("Error creating graph: " + status.ToString());
}

// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
        {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
        {},
        {graph_def.saver_def().restore_op_name()},
        nullptr);
if (!status.ok()) {
    throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);

For completeness, here's the Python equivalent:

Inference in Python

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))
    outputTensors = sess.run(outputOps, feed_dict=feedDict)
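
In the snippet above, tf.train.latest_checkpoint gets the prefix from the small text file named checkpoint that the saver writes next to the model files. The sketch below approximates that lookup with the standard library only; latest_checkpoint_prefix is a hypothetical name, the parsing assumes the usual model_checkpoint_path: "..." line, and the real function may do additional validation that is skipped here.

```python
import os
import re


def latest_checkpoint_prefix(checkpoint_dir):
    """Return the newest checkpoint prefix listed in <checkpoint_dir>/checkpoint.

    Returns None when the state file or the expected entry is missing.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as handle:
        # The line of interest looks like: model_checkpoint_path: "my-model"
        match = re.search(
            r'^model_checkpoint_path:\s*"(.*)"\s*$', handle.read(), re.MULTILINE
        )
    if match is None:
        return None
    path = match.group(1)
    # Relative entries are interpreted relative to the checkpoint directory.
    if not os.path.isabs(path):
        path = os.path.join(checkpoint_dir, path)
    return path
```

The returned value is again a prefix, so it can be passed to saver.restore() in Python or used as checkpointPath in the C++ code.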


Source: https://stackoverflow.com/questions/41992346/tensorflow-0-12-model-files
