What I would like to do is run multiple pre-trained TensorFlow nets at the same time. Because the names of some variables inside each net can be the same, the common solu
The easiest solution is to create different sessions that use separate graphs for each model:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

# Build a graph containing `net1`. CreateAlexNet() is the model builder from the question.
with tf.Graph().as_default() as net1_graph:
    net1 = CreateAlexNet()
    saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')

# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
    net2 = CreateAlexNet()
    saver2 = tf.train.Saver(...)
# Bind the second session to the second graph (not net1_graph).
sess2 = tf.Session(graph=net2_graph)
saver2.restore(sess2, 'epoch_50.ckpt')
```
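Why separate graphs help: within one graph, TensorFlow makes every name unique (`tf.Graph.unique_name` appends `_1`, `_2`, ... to a name that is already taken), so the second copy of a model no longer has the names stored in its checkpoint. The following is a toy pure-Python model of that renaming, not TensorFlow's actual code; `ToyGraph` and `build_toy_alexnet` are hypothetical, and the names are kept flat for simplicity (in real TensorFlow the suffix may land on an enclosing scope instead, e.g. `conv1_1/weights`).

```python
class ToyGraph:
    """Hypothetical stand-in for a tf.Graph: it only tracks which names are taken."""

    def __init__(self):
        self._used = {}

    def unique_name(self, name):
        # First use keeps the name; later uses get "_1", "_2", ... appended.
        count = self._used.get(name, 0)
        self._used[name] = count + 1
        return name if count == 0 else f"{name}_{count}"


def build_toy_alexnet(graph):
    # Hypothetical: the variable names one CreateAlexNet() call would request.
    return [graph.unique_name(n) for n in ("weights", "biases")]


# One shared graph: the second copy is renamed and no longer matches its checkpoint.
shared = ToyGraph()
shared_net1 = build_toy_alexnet(shared)
shared_net2 = build_toy_alexnet(shared)

# One graph per model: both copies keep the names their checkpoints expect.
net1_names = build_toy_alexnet(ToyGraph())
net2_names = build_toy_alexnet(ToyGraph())

print(shared_net2)  # ['weights_1', 'biases_1']
print(net2_names)   # ['weights', 'biases']
```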
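If this doesn't work for some reason and you have to use a single `tf.Session` (e.g. because you want to combine results from the two networks in another TensorFlow computation), the best solution is to build each network in its own name scope and to create a separate `tf.train.Saver` for each network, with an additional argument that remaps the variable names.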
When constructing the savers, you can pass a dictionary as the `var_list` argument, mapping the names of the variables in the checkpoint (i.e. without the name scope prefix) to the `tf.Variable` objects you've created in each model.

You can build the `var_list` programmatically, and you should be able to do something like the following:
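```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

# tf.name_scope prefixes variables created with tf.Variable; if CreateAlexNet()
# uses tf.get_variable, use tf.variable_scope instead.
with tf.name_scope("net1"):
    net1 = CreateAlexNet()
with tf.name_scope("net2"):
    net2 = CreateAlexNet()

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint.
# Slice rather than lstrip(), which removes a set of characters, not a prefix;
# v.op.name is used because checkpoint keys have no ":0" suffix.
net1_varlist = {v.op.name[len("net1/"):]: v
                for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint.
net2_varlist = {v.op.name[len("net2/"):]: v
                for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)

# ...
sess = tf.Session()
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")
```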
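Building the `var_list` keys is easy to get subtly wrong: `str.lstrip` removes any leading characters from a set rather than a prefix, and `tf.Variable.name` carries a `:0` output suffix that checkpoint keys do not. Here is a minimal pure-Python sketch of the name mapping; `checkpoint_key` is a hypothetical helper, and the variable names are made up for illustration.

```python
def checkpoint_key(var_name, scope):
    """Turn a scoped name such as "net1/conv1/weights:0" into the key it is
    assumed to have in the checkpoint ("conv1/weights")."""
    prefix = scope.rstrip("/") + "/"
    if not var_name.startswith(prefix):
        raise ValueError(f"{var_name!r} is not inside scope {prefix!r}")
    # Remove the prefix by slicing, then drop the ":0" output index.
    return var_name[len(prefix):].rsplit(":", 1)[0]


# Hypothetical variable names, as tf.Variable.name would report them.
names = ["net1/conv1/weights:0", "net1/tower/weights:0"]

# lstrip("net1/") strips any of n, e, t, 1 and "/", so it also eats the "t" of "tower".
print([n.lstrip("net1/") for n in names])
# ['conv1/weights:0', 'ower/weights:0']

print([checkpoint_key(n, "net1") for n in names])
# ['conv1/weights', 'tower/weights']
```

In the TensorFlow code, the same helper could build the dictionary as `{checkpoint_key(v.name, "net1"): v for v in ...}`.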
I had the same problem, and it bothered me for a long time. I found a good solution in these two questions: "Loading two models from Saver in the same Tensorflow session" and "TensorFlow checkpoint save and read".
By default, a `tf.train.Saver()` constructed without a `var_list` covers every variable that exists in the graph at that point, and associates each variable with the name of its op. This means that each time you construct a `tf.train.Saver()`, it also includes all of the variables created by the previous model-building calls. Therefore, you should create a separate graph for each model and run a separate session on each one.
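To make the default-saver behaviour concrete, here is a toy pure-Python model (not TensorFlow code) of a saver that snapshots the graph's variables when it is built. `ToyGraph` and `ToySaver` are hypothetical stand-ins; the point is only that a saver built after the second model also covers the first model's variables, so restoring it from one model's checkpoint would look for keys that are not there.

```python
class ToyGraph:
    """Hypothetical stand-in for tf.Graph: it only records variable names."""

    def __init__(self):
        self.variables = []


class ToySaver:
    """Hypothetical stand-in for tf.train.Saver.

    Like the real default, a saver built without var_list captures every
    variable that exists in the graph at construction time.
    """

    def __init__(self, graph, var_list=None):
        self.var_list = list(graph.variables if var_list is None else var_list)


graph = ToyGraph()

graph.variables += ["net1/conv1/weights", "net1/conv1/biases"]  # first model
saver1 = ToySaver(graph)

graph.variables += ["net2/conv1/weights", "net2/conv1/biases"]  # second model
saver2 = ToySaver(graph)  # also picks up net1's variables

# Restricting var_list to one model's scope avoids the overlap.
scoped_saver2 = ToySaver(graph, [v for v in graph.variables if v.startswith("net2/")])

print(len(saver1.var_list), len(saver2.var_list), len(scoped_saver2.var_list))
# 2 4 2
```

With one graph per model, each saver only ever sees its own model's variables, which is why the separate-graph approach avoids the problem without any renaming.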