TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  栀梦 (OP)
    2020-11-21 12:08

    In most cases, saving and restoring from disk using a tf.train.Saver is your best option:

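    import tensorflow as tf  # TensorFlow 1.x API
    
    ... # build your model
    saver = tf.train.Saver()  # create it after the model's variables exist
    
    with tf.Session() as sess:
        ... # initialize the variables and train the model
        saver.save(sess, "/tmp/my_great_model")  # also writes my_great_model.meta
    
    with tf.Session() as sess:
        saver.restore(sess, "/tmp/my_great_model")  # restored variables need no initializer
        ... # use the model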
    

    You can also save and restore the graph structure itself (see the MetaGraph documentation for details). By default, Saver.save() also writes the graph structure to a .meta file next to the checkpoint. tf.train.import_meta_graph() reads that file back, recreates the graph, and returns a Saver that you can use to restore the model's state:

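    import tensorflow as tf  # TensorFlow 1.x API
    
    # recreate the graph from the .meta file instead of rebuilding it in Python
    saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
    
    with tf.Session() as sess:
        saver.restore(sess, "/tmp/my_great_model")  # load the saved variable values
        ... # use the model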
    

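    If you want to check both approaches end to end, here is a minimal sketch that saves a toy model and then restores it into a fresh graph with import_meta_graph(). It assumes the TF 1.x graph API, reached through tensorflow.compat.v1 so that it also runs on TF 2.x; the toy model (tensors named x, w and y) and the temporary directory are made up for illustration. import_meta_graph() gives you no Python handles to your tensors, so the sketch looks them up by name:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode and TF 1.x-style variables

# Hypothetical checkpoint location; Saver.save() expects the parent directory to exist.
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "my_great_model")

# 1) Build a toy model y = w * x, "train" it, and save it.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[], name="x")
    w = tf.Variable(2.0, name="w")
    y = tf.multiply(w, x, name="y")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(w.assign(3.0))  # stand-in for training
        saver.save(sess, ckpt_path)  # also writes my_great_model.meta

# 2) In a brand-new graph, recreate the structure from the .meta file and restore w.
with tf.Graph().as_default() as graph:
    restorer = tf.train.import_meta_graph(ckpt_path + ".meta")
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        x_t = graph.get_tensor_by_name("x:0")  # look tensors up by name
        y_t = graph.get_tensor_by_name("y:0")
        result = sess.run(y_t, feed_dict={x_t: 4.0})

print(result)  # 12.0 (restored w = 3.0, times x = 4.0)
```

    The second graph never calls tf.Variable or tf.multiply: everything it runs comes from the .meta file, and only the value of w comes from the checkpoint.
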
    However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save a checkpoint every time the model improves on the validation set and then, if there is no progress for some time, roll back to the best model. Saving the model to disk every time it improves can slow training down tremendously. The trick is to save the variable states to memory and restore them later:

    import tensorflow as tf  # TensorFlow 1.x API
    
    ... # build your model
    
    # get a handle on the graph nodes we need to save/restore the model
    graph = tf.get_default_graph()
    gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
    init_values = [assign_op.inputs[1] for assign_op in assign_ops]
    
    with tf.Session() as sess:
        ... # train the model
    
        # when needed, save the model state to memory
        gvars_state = sess.run(gvars)
    
        # when needed, restore the model state
        feed_dict = {init_value: val
                     for init_value, val in zip(init_values, gvars_state)}
        sess.run(assign_ops, feed_dict=feed_dict)
    

    A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign that sets the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just clutter the graph), we reuse these existing assignment ops. The first input of each assignment op is a reference to the variable it initializes, and the second input (assign_op.inputs[1]) is the initial value. So to set any value we want instead of the initial value, we pass it in a feed_dict keyed by that initial-value tensor. This works because, in graph mode, TensorFlow lets you feed a value for any feedable tensor, not just for placeholders.
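
    Here is a small, self-contained demonstration of that trick: it looks up the automatically created X/Assign op, runs it once to apply the initial value, then runs it again while feeding a different value for assign_op.inputs[1]. It assumes graph mode with classic (non-resource) variables, reached through tensorflow.compat.v1; the variable X and the fed values are made up for illustration.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode with TF 1.x-style variables

graph = tf.Graph()
with graph.as_default():
    X = tf.Variable([1.0, 2.0], name="X")
    assign_op = graph.get_operation_by_name("X/Assign")  # created automatically
    init_value = assign_op.inputs[1]  # inputs[0] is the variable, inputs[1] its initial value

with tf.Session(graph=graph) as sess:
    sess.run(assign_op)  # normal initialization
    before = sess.run(X)
    # Feed a replacement for the initial-value tensor: X now receives our value instead.
    sess.run(assign_op, feed_dict={init_value: np.array([5.0, 6.0], dtype=np.float32)})
    after = sess.run(X)

print(before, after)  # [1. 2.] [5. 6.]
```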

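    To tie this back to early stopping, here is a framework-agnostic sketch of the training loop. The callables train_step, validation_loss, get_state and set_state, and the patience parameter, are hypothetical names: with the graph from the snippet above, get_state would be sess.run(gvars) and set_state would run assign_ops with the saved values in the feed_dict. The toy "model" in the usage part is a single number whose validation loss is lowest at 3.0.

```python
import copy
from typing import Callable, Dict, List, Optional

State = Dict[str, List[float]]


def train_with_early_stopping(
    train_step: Callable[[], None],
    validation_loss: Callable[[], float],
    get_state: Callable[[], State],
    set_state: Callable[[State], None],
    max_epochs: int = 100,
    patience: int = 5,
) -> float:
    """Train until the validation loss stops improving, then roll back to the best state."""
    best_loss = float("inf")
    best_state: Optional[State] = None
    epochs_without_progress = 0
    for _ in range(max_epochs):
        train_step()
        loss = validation_loss()
        if loss < best_loss:
            # Improvement: snapshot the state in memory (no disk I/O).
            best_loss = loss
            best_state = copy.deepcopy(get_state())
            epochs_without_progress = 0
        else:
            epochs_without_progress += 1
            if epochs_without_progress >= patience:
                break  # no progress for `patience` epochs in a row
    if best_state is not None:
        set_state(best_state)  # roll back to the best model seen
    return best_loss


# Toy usage: a one-parameter "model" whose training overshoots the optimum at 3.0.
model: State = {"w": [0.0]}


def toy_train_step() -> None:
    model["w"][0] += 1.0


def toy_validation_loss() -> float:
    return (model["w"][0] - 3.0) ** 2


def get_model_state() -> State:
    return model


def set_model_state(state: State) -> None:
    model["w"] = list(state["w"])


best = train_with_early_stopping(
    toy_train_step, toy_validation_loss, get_model_state, set_model_state,
    max_epochs=20, patience=2,
)
print(best, model["w"])  # 0.0 [3.0]
```

    The deepcopy keeps later training steps from mutating the snapshot; with TensorFlow it is redundant but harmless, since sess.run() already returns fresh NumPy arrays.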