TensorFlow: Restoring Multiple Graphs

抹茶落季 2021-01-01 05:47

Suppose we have two TensorFlow computation graphs, G1 and G2, with saved weights W1 and W2. Assume we build a new graph <

1 answer
  • 2021-01-01 06:27

    You can use two savers, each responsible for just one of the variables. A plain tf.train.Saver() covers every variable defined in the current graph, so it would also look for variables that a given checkpoint does not contain. Passing a list, as in tf.train.Saver([v1, ...]), restricts the saver to those variables. The tf.train.Saver constructor is documented here: https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver

    Here's a simple working example. Suppose you do your computation in a file "save_vars.py" and it has the following code:

    import tensorflow as tf
    
    # Graph 1 - set v1 to have value [1.0]
    g1 = tf.Graph()
    with g1.as_default():
        v1 = tf.Variable(tf.zeros([1]), name="v1")
        assign1 = v1.assign(tf.constant([1.0]))
        init1 = tf.global_variables_initializer()  # tf.initialize_all_variables() before TF 0.12
        save1 = tf.train.Saver()
    
    # Graph 2 - set v2 to have value [2.0]
    g2 = tf.Graph()
    with g2.as_default():
        v2 = tf.Variable(tf.zeros([1]), name="v2")
        assign2 = v2.assign(tf.constant([2.0]))
        init2 = tf.global_variables_initializer()  # tf.initialize_all_variables() before TF 0.12
        save2 = tf.train.Saver()
    
    # Do the computation for graph 1 and save
    sess1 = tf.Session(graph=g1)
    sess1.run(init1)
    print(sess1.run(assign1))
    save1.save(sess1, "tmp/v1.ckpt")
    
    # Do the computation for graph 2 and save
    sess2 = tf.Session(graph=g2)
    sess2.run(init2)
    print(sess2.run(assign2))
    save2.save(sess2, "tmp/v2.ckpt")
    

    Make sure a tmp directory exists, then run python save_vars.py to produce the saved checkpoint files.
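    If you would rather not create the directory by hand, a minimal sketch like the following could go near the top of save_vars.py. The helper name ensure_checkpoint_dir is hypothetical (it is not part of TensorFlow), and the default "tmp" simply matches the path used in the save calls above.

```python
import os


def ensure_checkpoint_dir(path="tmp"):
    """Create the checkpoint directory if it is missing and return its path.

    Hypothetical helper for illustration; "tmp" matches the directory
    that save_vars.py writes its checkpoints into.
    """
    if not os.path.isdir(path):
        os.makedirs(path)
    return path


if __name__ == "__main__":
    # Call this before save1.save(...) / save2.save(...).
    ensure_checkpoint_dir()
```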

    Now, you can restore using a file named "restore_vars.py" with the following code:

    import tensorflow as tf
    
    # The variables v1 and v2 that we want to restore
    v1 = tf.Variable(tf.zeros([1]), name="v1")
    v2 = tf.Variable(tf.zeros([1]), name="v2")
    
    # saver1 will only look for v1
    saver1 = tf.train.Saver([v1])
    # saver2 will only look for v2
    saver2 = tf.train.Saver([v2])
    with tf.Session() as sess:
        saver1.restore(sess, "tmp/v1.ckpt")
        saver2.restore(sess, "tmp/v2.ckpt")
        print(sess.run(v1))
        print(sess.run(v2))
    

    and when you run python restore_vars.py, the output should be

    [1.]
    [2.]
    

    (at least on my computer that's the output). Feel free to post a comment if anything was unclear.
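    In a combined graph the variables often live under different names than they had when they were saved, for example inside scopes. tf.train.Saver also accepts a dict that maps the name stored in the checkpoint to the variable in the new graph, so each saver can still pick out its own checkpoint. The sketch below assumes a TensorFlow release that ships tf.compat.v1 (1.13 or later, including 2.x); the scope names "g1" and "g2", the helper function names, and the temporary directory are illustrative assumptions, not part of the original setup.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API; assumes TensorFlow 1.13+ or 2.x


def save_single_variable(name, value, ckpt_path):
    """Build a one-variable graph, set the variable, and checkpoint it."""
    graph = tf.Graph()
    with graph.as_default():
        tf1.Variable(tf.constant([value]), name=name)
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, ckpt_path)


def restore_into_combined_graph(ckpt1, ckpt2):
    """Restore v1 and v2 into one graph where they live under new scopes."""
    graph = tf.Graph()
    with graph.as_default():
        # In the combined graph the variables are named g1/v1 and g2/v2.
        with tf1.variable_scope("g1"):
            a = tf1.get_variable("v1", shape=[1])
        with tf1.variable_scope("g2"):
            b = tf1.get_variable("v2", shape=[1])
        # Dict form: {name in checkpoint: variable in this graph}.
        saver1 = tf1.train.Saver({"v1": a})
        saver2 = tf1.train.Saver({"v2": b})
        with tf1.Session(graph=graph) as sess:
            # restore() assigns the values, so no initializer is run here.
            saver1.restore(sess, ckpt1)
            saver2.restore(sess, ckpt2)
            return sess.run([a, b])


if __name__ == "__main__":
    ckpt_dir = tempfile.mkdtemp()  # illustrative location for the checkpoints
    path1 = os.path.join(ckpt_dir, "v1.ckpt")
    path2 = os.path.join(ckpt_dir, "v2.ckpt")
    save_single_variable("v1", 1.0, path1)
    save_single_variable("v2", 2.0, path2)
    print(restore_into_combined_graph(path1, path2))
```

    The list form used in restore_vars.py is enough when the names match exactly; the dict form is only needed once the names in the new graph diverge from those in the checkpoints.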
