Tensorflow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  一生所求
    2020-11-21 12:20

    If the model was saved by your own code (an internal model), you just create a restorer over all variables (TF 1.x API; under TF 2.x the same calls live under tf.compat.v1):

    restorer = tf.train.Saver(tf.global_variables())  # tf.all_variables() is deprecated


    and use it to restore the variables in the current session:

    restorer.restore(sess, model_file)  # sess: the current tf.Session; model_file: the checkpoint path prefix
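A minimal end-to-end sketch of the save/restore cycle described above, assuming TensorFlow 2.x with the TF1-style graph API reached through tf.compat.v1 (under TF 1.x the same calls exist without the compat.v1 prefix). The variable names "w" and "b" and the temporary checkpoint path are hypothetical. The key point is that the restoring program rebuilds the same variables and calls restore instead of running the initializer.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical checkpoint location; Saver writes files named with this prefix.
ckpt_dir = tempfile.mkdtemp()
model_file = os.path.join(ckpt_dir, "model.ckpt")


def build_graph():
    """Create the model's variables in a fresh graph (graph mode, even under TF 2.x)."""
    graph = tf.Graph()
    with graph.as_default():
        # "w" uses the default (random) initializer, so a failed restore would show up.
        w = tf.compat.v1.get_variable("w", shape=[2, 2])
        b = tf.compat.v1.get_variable("b", shape=[2], initializer=tf.compat.v1.zeros_initializer())
    return graph, w, b


# 1. Save: after training, write all global variables to a checkpoint.
graph, w, b = build_graph()
with graph.as_default():
    saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # ... training steps would go here ...
        trained_w = sess.run(w)
        saver.save(sess, model_file)

# 2. Restore: rebuild the same variables, then load their values from disk
#    instead of running the initializer.
graph2, w2, b2 = build_graph()
with graph2.as_default():
    restorer = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
    with tf.compat.v1.Session(graph=graph2) as sess:
        restorer.restore(sess, model_file)
        restored_w = sess.run(w2)

print(np.allclose(trained_w, restored_w))  # True
```

Because the variables are created with the same names in both graphs, the default name-based matching of Saver is enough here; no explicit mapping is needed.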
    

    For an external (e.g. pretrained) model, you need to specify a mapping from its variable names to your variable names. You can list the variable names stored in its checkpoint with the command

    python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
    

    The inspect_checkpoint.py script is in the tensorflow/python/tools folder of the TensorFlow source tree.
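If you prefer not to run the script from the source tree, tf.train.list_variables returns the same (name, shape) information from Python. The sketch below assumes TensorFlow 2.x (graph-mode calls via tf.compat.v1). It first writes a tiny checkpoint with hypothetical variable names so that it is self-contained; in practice ckpt_path would be the pretrained model's checkpoint prefix.

```python
import os
import tempfile

import tensorflow as tf

# Build and save a tiny stand-in checkpoint (hypothetical names "conv1/weights", "conv1/biases").
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
graph = tf.Graph()
with graph.as_default():
    tf.compat.v1.get_variable("conv1/weights", shape=[3, 3, 1, 8])
    tf.compat.v1.get_variable("conv1/biases", shape=[8])
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(sess, ckpt_path)

# List the (name, shape) pairs stored in the checkpoint, similar to what
# inspect_checkpoint.py prints when given only --file_name.
variables = tf.train.list_variables(ckpt_path)
for name, shape in variables:
    print(name, shape)
```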

    To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here
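Independently of any helper library, Saver itself accepts a dictionary as var_list, whose keys are names inside the checkpoint and whose values are variables in your graph. The following is a minimal sketch of that mapping, assuming TensorFlow 2.x with tf.compat.v1. The names "resnet/conv1/weights" and "my_model/conv1/w" are hypothetical, and a stand-in "pretrained" checkpoint is written first so the example runs on its own.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_path = os.path.join(tempfile.mkdtemp(), "pretrained.ckpt")

# Stand-in for the external model: normally you would only have its checkpoint files.
src = tf.Graph()
with src.as_default():
    pre_w = tf.compat.v1.get_variable("resnet/conv1/weights", shape=[3, 3])
    with tf.compat.v1.Session(graph=src) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        pretrained_value = sess.run(pre_w)
        tf.compat.v1.train.Saver().save(sess, ckpt_path)

# Our own model, whose variable has a different name.
dst = tf.Graph()
with dst.as_default():
    my_w = tf.compat.v1.get_variable("my_model/conv1/w", shape=[3, 3])
    # Keys: variable names in the checkpoint; values: variables in this graph.
    name_map = {"resnet/conv1/weights": my_w}
    restorer = tf.compat.v1.train.Saver(var_list=name_map)
    with tf.compat.v1.Session(graph=dst) as sess:
        # Initialize everything first so variables not in the map get values,
        # then overwrite the mapped ones from the checkpoint.
        sess.run(tf.compat.v1.global_variables_initializer())
        restorer.restore(sess, ckpt_path)
        restored_value = sess.run(my_w)

print(np.allclose(pretrained_value, restored_value))  # True
```

Running the initializer before restore is a design choice for partial restores (e.g. a new classification head that has no counterpart in the pretrained checkpoint): those variables keep their fresh initial values while the mapped ones are loaded.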
