TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  别跟我提以往
    2020-11-21 12:19

    Here's my simple solution for the two basic cases, which differ in whether you rebuild the graph in code at runtime or load it from a file.

    This answer holds for TensorFlow 0.12+ (including 1.0). In TensorFlow 2.x, the same calls are still available under tf.compat.v1.

    Rebuilding the graph in code

    Saving

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')
    

    Loading

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # now you can use the graph, continue training or whatever
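
    For a concrete version of the two snippets above, here is a minimal sketch. The toy graph (a single variable `w` and an op `doubled`), the `build_graph` helper, and the temporary checkpoint directory are illustrative placeholders, not part of the recipe itself. It assumes the TF1-style API; on TensorFlow 2.x that is reachable through `tensorflow.compat.v1`, while on 1.x a plain `import tensorflow as tf` should work and the `disable_v2_behavior()` call can be dropped.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x install; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()  # switch to graph/session mode


def build_graph():
    """Hypothetical toy graph: one variable and an op that doubles it."""
    w = tf.Variable(0.0, name="w")
    doubled = tf.multiply(w, 2.0, name="doubled")
    return w, doubled


# Placeholder location for the checkpoint files.
ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "my-model")

# Saving
tf.reset_default_graph()
w, doubled = build_graph()
saver = tf.train.Saver()  # create the saver after the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(w.assign(21.0))  # stand-in for training
    saver.save(sess, ckpt_prefix)

# Loading: rebuild the same graph in code, then restore the variable values
tf.reset_default_graph()
w, doubled = build_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    restored_value = sess.run(doubled)  # uses the restored value of w
```

    Note that `build_graph()` runs again before `tf.train.Saver()` on the loading side, so the variable names match those stored in the checkpoint and no initializer needs to run.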
    

    Loading the graph from a file as well

    When using this technique, make sure all your layers/variables have explicitly set, unique names. Otherwise TensorFlow will make the names unique itself, and they will then differ from the names stored in the file. This is not a problem in the previous technique, because the names are "mangled" the same way when saving and when loading.
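
    To see the automatic renaming in action, here is a small sketch. The variable names are made up for illustration, and it again assumes the TF1-style API through `tensorflow.compat.v1`.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 2.x install; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()

# Two variables requesting the same name: TensorFlow adds a suffix to the second.
tf.reset_default_graph()
a = tf.Variable(0.0, name="w")
b = tf.Variable(0.0, name="w")
auto_names = [a.name, b.name]  # ['w:0', 'w_1:0']

# Explicit, unique names are kept exactly as written.
tf.reset_default_graph()
c = tf.Variable(0.0, name="layer1_w")
d = tf.Variable(0.0, name="layer2_w")
explicit_names = [c.name, d.name]  # ['layer1_w:0', 'layer2_w:0']
```

    The suffix depends on the order in which the variables are created, which is why explicit names are the safer choice when you look tensors up by name later.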

    Saving

    graph = ... # build the graph
    
    for op in [ ... ]:  # operators you want to use after restoring the model
        tf.add_to_collection('ops_to_restore', op)
    
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')
    

    Loading

    with ... as sess:  # your session object
        saver = tf.train.import_meta_graph('my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # the operators come back in the same order in which you added them to the collection
        ops = tf.get_collection('ops_to_restore')
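
    A runnable sketch of this second technique might look as follows. The toy graph (placeholder `x`, variable `w`, output `y`) and the temporary directory are illustrative assumptions; only the collection name `ops_to_restore` and the `my-model` prefix come from the snippets above. As before, it assumes the TF1-style API through `tensorflow.compat.v1`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x install; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()

# Placeholder location for the checkpoint files.
ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "my-model")

# Saving: build a toy graph and register the tensors needed after restoring.
tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=[], name="x")
w = tf.Variable(3.0, name="w")
y = tf.multiply(w, x, name="y")
for op in [x, y]:
    tf.add_to_collection("ops_to_restore", op)

saver = tf.train.Saver()  # create the saver after the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_prefix)  # also writes my-model.meta

# Loading: no graph-building code, the structure comes from the .meta file.
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    x_restored, y_restored = tf.get_collection("ops_to_restore")
    result = sess.run(y_restored, feed_dict={x_restored: 2.0})  # 3.0 * 2.0
```

    Putting both the input placeholder and the output into the collection means the loading side needs no knowledge of tensor names at all.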
    
