Run multiple pre-trained TensorFlow nets at the same time

悲&欢浪女 2020-12-14 14:11

What I would like to do is to run multiple pre-trained TensorFlow nets at the same time. Because the names of some variables inside each net can be the same, the common solution is to build each net inside its own name scope.

2 Answers
  • 2020-12-14 14:23

    The easiest solution is to create different sessions that use separate graphs for each model:

    import tensorflow as tf  # TensorFlow 1.x API

    # Build a graph containing `net1`.
    with tf.Graph().as_default() as net1_graph:
      net1 = CreateAlexNet()
      saver1 = tf.train.Saver(...)
    sess1 = tf.Session(graph=net1_graph)
    saver1.restore(sess1, 'epoch_10.ckpt')
    
    # Build a separate graph containing `net2`.
    with tf.Graph().as_default() as net2_graph:
      net2 = CreateAlexNet()
      saver2 = tf.train.Saver(...)
    sess2 = tf.Session(graph=net2_graph)
    saver2.restore(sess2, 'epoch_50.ckpt')
    
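    Each session then holds its own copy of the restored weights, so the two nets can be run independently. A minimal sketch of the inference step, assuming `CreateAlexNet()` returns the network's output tensor and each graph contains a placeholder named `"input"` (both are assumptions; the original snippet elides them):

    import numpy as np

    # Hypothetical input batch; the shape depends on your model.
    images = np.zeros((1, 227, 227, 3), dtype=np.float32)

    # Each call uses its own session, graph, and restored weights.
    out1 = sess1.run(net1, {net1_graph.get_tensor_by_name("input:0"): images})
    out2 = sess2.run(net2, {net2_graph.get_tensor_by_name("input:0"): images})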

    If this doesn't work for some reason, and you have to use a single tf.Session (e.g. because you want to combine results from the two networks in another TensorFlow computation), the best solution is to:

    1. Create the different networks in name scopes as you are already doing, and
    2. Create separate tf.train.Saver instances for the two networks, with an additional argument to remap the variable names.

    When constructing the savers, you can pass a dictionary as the var_list argument, mapping the names of the variables in the checkpoint (i.e. without the name scope prefix) to the tf.Variable objects you've created in each model.

    You can build the var_list programmatically, and you should be able to do something like the following:

    with tf.name_scope("net1"):
      net1 = CreateAlexNet()
    with tf.name_scope("net2"):
      net2 = CreateAlexNet()
    
    # Strip off the "net1/" prefix to get the names of the variables in the
    # checkpoint. (Use the op name, which has no ":0" suffix, and slice the
    # prefix off -- str.lstrip strips a character set, not a prefix.)
    net1_varlist = {v.op.name[len("net1/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)
    
    # Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
    net2_varlist = {v.op.name[len("net2/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)
    
    # ...
    net1_saver.restore(sess, "epoch_10.ckpt")
    net2_saver.restore(sess, "epoch_50.ckpt")
    
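    The `sess` used in the restore calls above is a single tf.Session over the shared default graph. A minimal sketch of that setup, with a hypothetical downstream op combining the two nets (the averaging step is an assumption; the original answer only says the results are combined):

    # One session over the default graph that holds both nets.
    sess = tf.Session()

    # Hypothetical ensemble op -- assumes CreateAlexNet() returns the
    # network's output tensor, which the original code elides.
    ensemble = (net1 + net2) / 2.0
    # After both savers have restored, sess.run(ensemble, feed_dict=...)
    # evaluates both nets in a single step.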
  • 2020-12-14 14:37

    I had the same problem, and it bothered me for a long time. I found a good solution in these two threads: Loading two models from Saver in the same Tensorflow session and TensorFlow checkpoint save and read.

    By default, tf.train.Saver() associates each variable with the name of its corresponding op, and it captures every variable in the current default graph at the time it is constructed. As a result, a Saver built after several models have been added to the same graph includes the variables of all of them, and its restore calls will expect a checkpoint containing all of those variables. Therefore, you should build each model in its own graph and run each graph in its own session.
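    A minimal sketch of why this bites (`CreateAlexNet` and the checkpoint name are carried over from the first answer):

    # Two models built in the same default graph.
    net1 = CreateAlexNet()
    saver_a = tf.train.Saver()  # captures only net1's variables

    net2 = CreateAlexNet()
    saver_b = tf.train.Saver()  # captures the variables of BOTH nets

    # saver_b.restore(sess, "epoch_10.ckpt") would fail: the checkpoint
    # holds one net's weights, while saver_b expects entries for both.
    # Building each model in its own tf.Graph sidesteps this entirely.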
