I'm using Tensorflow v1.1 and I've been trying to figure out how to use my EMA'ed weights for inference, but no matter what I do I keep getting the error
I'd like to suggest a way to make the best use of the trained variables in the checkpoint.
Keep in mind that every variable in the saver's var_list should be contained in the checkpoint you configured. You can check the variables in the saver with:
print(restore_vars)
and the variables in the checkpoint with:
vars_in_checkpoint = tf.train.list_variables("model_ex1")
in your case.
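As a minimal sketch of that check (assuming restore_vars is the list you passed to your tf.train.Saver and the checkpoint lives in the "model_ex1" directory; both names are placeholders for your own setup), you can compare the two sets of names directly:

import tensorflow as tf

# Hypothetical list of variables you intend to restore (e.g. your EMA shadow variables).
restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

# Variable names without the ":0" suffix, which is how the checkpoint stores them.
restore_names = {v.name.split(":")[0] for v in restore_vars}

# (name, shape) pairs actually stored in the checkpoint directory.
vars_in_checkpoint = tf.train.list_variables("model_ex1")
checkpoint_names = {name for name, shape in vars_in_checkpoint}

# Any name printed here is what triggers the "not found in checkpoint" error.
missing = restore_names - checkpoint_names
print("Variables missing from the checkpoint:", missing)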
If restore_vars are all included in vars_in_checkpoint, the error will not be raised; otherwise, initialize all variables first:
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
sess.run(tf.variables_initializer(all_variables))
This initializes all variables, whether they are in the checkpoint or not. You can then filter out the variables in restore_vars that are not included in the checkpoint (suppose all variables with ExponentialMovingAverage in their names are not in the checkpoint):
temp_saver = tf.train.Saver(
    var_list=[v for v in all_variables if "ExponentialMovingAverage" not in v.name])
ckpt_state = tf.train.get_checkpoint_state("model_ex1")
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
This may save some time compared to training the model from scratch. (In my scenario the restored variables gave no significant improvement over training from scratch at the beginning, since all the old optimizer variables are abandoned. But it can still accelerate the optimization process significantly, I think, because it works like pretraining some of the variables.)
In any case, some variables are worth restoring, such as embeddings and some layers.
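As a hedged illustration of restoring only such a subset (the "embedding" substring below is just a placeholder for whatever layers you actually want to keep), the pattern is the same as the ExponentialMovingAverage filter above:

import tensorflow as tf

all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)

with tf.Session() as sess:
    # Give every variable a value first, so anything not restored is still usable.
    sess.run(tf.variables_initializer(all_variables))

    # Restore only the variables whose names mark them as embeddings (placeholder name).
    embedding_saver = tf.train.Saver(
        var_list=[v for v in all_variables if "embedding" in v.name])
    ckpt_state = tf.train.get_checkpoint_state("model_ex1")
    embedding_saver.restore(sess, ckpt_state.model_checkpoint_path)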