Tensorflow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  心在旅途
    2020-11-21 12:23

    According to the documentation for newer TensorFlow versions, tf.train.Checkpoint is the preferred way to save and restore a model:

    Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to tf.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code.
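    To make the "named edges" point concrete, here is a minimal sketch (the `Net` class and the variable names are hypothetical, chosen only for illustration): a value saved from a variable named `weights` is restored into a variable named `something_else`, because object-based matching follows the attribute path root → `model` → `w` rather than `variable.name`.

```python
import os
import tempfile

import tensorflow as tf


class Net(tf.Module):
    """Hypothetical module that tracks one variable under the attribute `w`."""

    def __init__(self, init_value, var_name):
        super().__init__()
        # `var_name` sets variable.name, which tf.train.Saver would match on.
        # Object-based checkpoints match on the attribute edge `w` instead.
        self.w = tf.Variable(init_value, name=var_name)


ckpt_dir = tempfile.mkdtemp()
saved = Net([1.0, 2.0, 3.0], var_name="weights")
save_path = tf.train.Checkpoint(model=saved).save(os.path.join(ckpt_dir, "ckpt"))

# Different variable name, same object graph (root -> model -> w).
restored = Net([0.0, 0.0, 0.0], var_name="something_else")
status = tf.train.Checkpoint(model=restored).restore(save_path)
status.assert_existing_objects_matched()  # every existing Python object found a value
print(restored.w.numpy())
```

    With tf.train.Saver the differing names would have prevented the match; here only the object structure has to line up.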

    Here is an example:

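    import os
    import tensorflow as tf
    
    # TF 2.x executes eagerly by default; on TF 1.x call
    # tf.enable_eager_execution() at program start instead.
    
    checkpoint_directory = "/tmp/training_checkpoints"
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    
    # Toy model, optimizer and data so the snippet runs; replace with your own.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
    x, y = tf.random.normal([32, 4]), tf.random.normal([32, 1])
    num_training_steps = 10
    
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    latest = tf.train.latest_checkpoint(checkpoint_directory)  # None on the first run
    status = checkpoint.restore(latest)
    for _ in range(num_training_steps):
        with tf.GradientTape() as tape:
            # Variables are created on the first call and restored on creation.
            loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    if latest:
        status.assert_consumed()  # Optional sanity check, only meaningful after a restore.
    checkpoint.save(file_prefix=checkpoint_prefix)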
    

    More information and example here.
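    For the second part of the question, restoring later into freshly built objects (for example in a new process), here is a minimal sketch. It assumes a hypothetical `LazyDense` module that, like Keras layers, creates its weight only on the first call. Because the weight does not exist yet when `restore` runs, the checkpoint value is queued and assigned the moment the variable is created, which is the restore-on-create behaviour described in the quote.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class LazyDense(tf.Module):
    """Hypothetical layer that creates its variable on the first call."""

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.w = None  # not created until the input size is known

    def __call__(self, x):
        if self.w is None:
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.units]))
        return tf.matmul(x, self.w)


ckpt_dir = tempfile.mkdtemp()

# Training run: build the variable, then save.
trained = LazyDense(units=2)
trained(tf.ones([1, 3]))
tf.train.Checkpoint(model=trained).save(os.path.join(ckpt_dir, "ckpt"))

# Later run: a fresh object with no variable yet. The restore is deferred
# and applied when `w` is created inside the first call.
fresh = LazyDense(units=2)
status = tf.train.Checkpoint(model=fresh).restore(tf.train.latest_checkpoint(ckpt_dir))
fresh(tf.ones([1, 3]))
status.assert_consumed()

np.testing.assert_allclose(trained.w.numpy(), fresh.w.numpy())
```

    The restoring code has to rebuild the same object structure (same attribute names) as the saving code; the random initial value of `fresh.w` is overwritten by the checkpoint.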
