Loading two models from Saver in the same TensorFlow session

Asked 2020-12-14 05:17

I have two networks: a Model which generates output and an Adversary which grades the output.

Both have been trained separately, but now I need to load both of them from their Saver checkpoints and use them together.

3 Answers
  • 2020-12-14 05:41

    Solving this problem took a long time so I'm posting my likely imperfect solution in case anyone else needs it.

    To diagnose the problem I manually looped through the variables and assigned them one by one, and noticed that a variable's name changed after each assignment. This is described here: TensorFlow checkpoint save and read

    Based on the advice in that post, I ran each model in its own graph. That also meant running each graph in its own session and handling session management differently.

    First I created two graphs

    model_graph = tf.Graph()
    with model_graph.as_default():
        model = Model(args)
    
    adv_graph = tf.Graph()
    with adv_graph.as_default():
        adversary = Adversary(adv_args)
    

    Then two sessions

    adv_sess = tf.Session(graph=adv_graph)
    sess = tf.Session(graph=model_graph)
    

    Then I initialised the variables in each session and restored each graph separately

    with sess.as_default():
        with model_graph.as_default():
            tf.global_variables_initializer().run()
            model_saver = tf.train.Saver(tf.global_variables())
            model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
            model_saver.restore(sess, model_ckpt.model_checkpoint_path)
    
    with adv_sess.as_default():
        with adv_graph.as_default():
            tf.global_variables_initializer().run()
            adv_saver = tf.train.Saver(tf.global_variables())
            adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
            adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
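
    If a restore fails with errors about missing or renamed variables, it can help to inspect what the checkpoint actually stores. A minimal sketch, using a throwaway variable and a temporary directory as hypothetical stand-ins for the real models (written against the TF1-style API via tf.compat.v1 so it also runs under TF 2.x):

    ```python
    import os
    import tempfile
    import tensorflow.compat.v1 as tf  # TF1-style API; plain `import tensorflow as tf` on TF 1.x
    tf.disable_eager_execution()

    # Hypothetical tiny graph: one variable, saved and then inspected.
    g = tf.Graph()
    with g.as_default():
        w = tf.get_variable("w", shape=[2], initializer=tf.zeros_initializer())
        saver = tf.train.Saver()
        with tf.Session(graph=g) as s:
            s.run(tf.global_variables_initializer())
            path = saver.save(s, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

    # list_variables returns (name, shape) pairs stored in the checkpoint,
    # which is what Saver will try to match against your graph's variable names.
    for name, shape in tf.train.list_variables(path):
        print(name, shape)
    ```

    Comparing these stored names against the names in your graph usually pinpoints the mismatch.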
    

    From here, whenever one of the sessions was needed, I wrapped the relevant TensorFlow calls in a with sess.as_default(): (or with adv_sess.as_default():) block. At the end I close the sessions manually

    sess.close()
    adv_sess.close()
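
    To make the pattern concrete, here is a self-contained sketch of the same two-graph, two-session flow, with toy ops standing in for Model and Adversary (the real restored networks would replace the constant and placeholder):

    ```python
    import tensorflow.compat.v1 as tf  # TF1-style API; plain `import tensorflow as tf` on TF 1.x
    tf.disable_eager_execution()

    # Toy stand-in for the Model: one op in its own graph.
    model_graph = tf.Graph()
    with model_graph.as_default():
        model_out = tf.constant([1.0, 2.0])          # "Model output"

    # Toy stand-in for the Adversary: grades the Model's output.
    adv_graph = tf.Graph()
    with adv_graph.as_default():
        adv_in = tf.placeholder(tf.float32, [2])
        adv_score = tf.reduce_sum(adv_in)            # "Adversary grade"

    sess = tf.Session(graph=model_graph)
    adv_sess = tf.Session(graph=adv_graph)

    # as_default() makes eval()/run() target the right session for each graph.
    with sess.as_default():
        output = model_out.eval()

    with adv_sess.as_default():
        score = adv_score.eval(feed_dict={adv_in: output})

    sess.close()
    adv_sess.close()
    ```

    The output of one session is plain NumPy data, so it can be fed into the other session's graph without any cross-graph wiring.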
    
  • 2020-12-14 05:48

    The answer marked as correct does not show how to load two different models into one session explicitly, so here is how:

    1. Create a separate name scope for each model you want to load.

    2. Create two Saver instances, one per network, each mapping checkpoint names to that network's variables.

    3. Restore each one from its corresponding checkpoint file.

    with tf.Session() as sess:
        with tf.name_scope("net1"):
          net1 = Net1()
        with tf.name_scope("net2"):
          net2 = Net2()
    
    net1_varlist = {v.op.name[len("net1/"):]: v   # slice off the exact "net1/" prefix
                        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
        net1_saver = tf.train.Saver(var_list=net1_varlist)
    
        net2_varlist = {v.op.name[len("net2/"):]: v   # slice off the exact "net2/" prefix
                        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
        net2_saver = tf.train.Saver(var_list=net2_varlist)
    
        net1_saver.restore(sess, "net1.ckpt")
        net2_saver.restore(sess, "net2.ckpt")
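
    One caveat about building those dictionary keys: str.lstrip takes a set of characters, not a prefix, so using it to remove a scope prefix can silently eat extra characters from the variable name. Slicing off the prefix length is safe:

    ```python
    # lstrip("net1/") strips any leading run of the characters n, e, t, 1, /,
    # not the literal prefix "net1/" - names that continue with those characters break.
    name = "net1/net_out"
    bad = name.lstrip("net1/")      # "_out"  (the "net" of "net_out" was stripped too)
    good = name[len("net1/"):]      # "net_out"
    print(bad, good)
    ```

    This is why the variable-list comprehension above slices with len("net1/") rather than calling lstrip.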
    
  • 2020-12-14 05:51

    Please check this:

    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
    

    It should be "adv", shouldn't it?
