I have two networks: a Model
which generates output and an Adversary
which grades the output.
Both have been trained separately, but now I need to combine their outputs and use both networks together.
Solving this problem took a long time, so I'm posting my likely imperfect solution in case anyone else needs it.
To diagnose the problem, I manually looped through the variables and assigned them one by one. I then noticed that after assigning a variable, its name would change. This is described here: TensorFlow checkpoint save and read
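A rough sketch of such a loop (ckpt_path and the use of tf.train.NewCheckpointReader are my assumptions for illustration, not code from the original post):

# Hypothetical reconstruction of the one-by-one assignment loop.
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.global_variables():
    sess.run(v.assign(reader.get_tensor(v.op.name)))
    print(v.op.name)  # watching the names here is what exposed the renaming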
Based on the advice in that post, I ran each model in its own graph. That also meant running each graph in its own session, so the session management had to be handled differently.
First I created two graphs:
model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)
Then I created two sessions:
adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)
Then I initialised the variables in each session and restored each graph separately:
with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
From here, whenever one of the sessions was needed, I wrapped any TensorFlow calls for that session in with sess.as_default(): (or with adv_sess.as_default(): for the adversary).
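For example, something like this (model.output, model.input, adversary.score, adversary.input, and batch are hypothetical names, not from the original post):

with sess.as_default():
    # .eval() runs against the default session, which is why the wrapper matters
    sample = model.output.eval(feed_dict={model.input: batch})

with adv_sess.as_default():
    score = adversary.score.eval(feed_dict={adversary.input: sample})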
At the end I manually closed the sessions:
sess.close()
adv_sess.close()
The answer marked as correct does not explicitly show how to load two different models into one session, so here is my answer:
1. Create two different variable scopes for the models you want to load.
2. Initialize two savers, one per network, that load the parameters for the variables in each scope.
3. Restore from the corresponding checkpoint files.
with tf.Session() as sess:
    # Use variable_scope rather than name_scope: name_scope does not
    # prefix variables created with tf.get_variable.
    with tf.variable_scope("net1"):
        net1 = Net1()
    with tf.variable_scope("net2"):
        net2 = Net2()

    # Map checkpoint names (without the scope prefix) to the variables in
    # this graph. Note that str.lstrip strips a set of characters, not a
    # prefix, so slice the prefix off instead.
    net1_varlist = {v.op.name[len("net1/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)

    net2_varlist = {v.op.name[len("net2/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)

    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")
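If restoring fails with "not found in checkpoint" errors, comparing the names on both sides usually pinpoints the mismatch; a quick check might look like this (reusing the paths above):

# Inspect what the checkpoint actually contains vs. what the graph expects.
reader = tf.train.NewCheckpointReader("net1.ckpt")
print(reader.get_variable_to_shape_map())  # names stored in the file
print([v.op.name for v in
       tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")])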
Please check this:
adv_varlist = {v.name.lstrip("avd/")[:-2]: v
Shouldn't "avd" be "adv"?