I am trying to create an ensemble of many trained models. All models have the same graph and just differ in their weights. I am creating the model graph using tf.get_variable.
There are several questions regarding this topic and a lot of possible answers/ways to do it. Here, I'd like to show what I came up with as the most elegant and cleanest way of making an ensemble of N models, where N is arbitrary. This solution was tested with tf 1.12.0 and python 2.7.
The following code snippet is what you are looking for (explanation below):
import tensorflow as tf
import numpy as np

num_of_ensembles = N  # N is the (arbitrary) number of models in the ensemble
savers = list()
placeholders = list()
inference_ops = list()

for i in xrange(num_of_ensembles):
    with tf.name_scope('model_{}'.format(i)):
        savers.append(tf.train.import_meta_graph('saved_model.ckpt.meta'))

graph = tf.get_default_graph()

for i in xrange(num_of_ensembles):
    placeholders.append(graph.get_operation_by_name('model_{}/input_ph'.format(i)).outputs[0])
    inference_ops.append(graph.get_operation_by_name('model_{}/last_operation_in_the_network'.format(i)).outputs[0])

with tf.Session() as sess:
    for i in xrange(num_of_ensembles):
        savers[i].restore(sess, 'saved_model.ckpt')
        # your_input: a sample with the shape the model expects
        prediction = sess.run(inference_ops[i], feed_dict={placeholders[i]: np.random.rand(*your_input.shape)})
So, the first thing to do is to import the meta graph of each model. As suggested in the comments above, the key is to create its own scope for each model of the ensemble, in order to add a prefix like model_001/, model_002/, ... to each variable scope. This will allow you to restore N different models, each with its own independent variables.
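As a side note, instead of relying on the surrounding tf.name_scope, tf.train.import_meta_graph also accepts an import_scope argument, which should give the same prefixing effect. A minimal sketch, assuming the same checkpoint file as above:

for i in xrange(num_of_ensembles):
    # import_scope prefixes every imported node with 'model_<i>/'
    savers.append(tf.train.import_meta_graph('saved_model.ckpt.meta',
                                             import_scope='model_{}'.format(i)))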
All these graphs will live in the current default graph. Now, when you load a model, you have to extract from the graph the inputs, outputs, and operations that you wish to use, and store them in new variables. To do so, you'll need to know the names of those tensors in the old model. You can inspect all saved operations using ops = graph.get_operations(). In the example above, the first operation is the placeholder assignment /input_ph, while the last operation is named /last_operation_in_the_network (normally, if the author of the network doesn't specify the name field for each layer, you will find something like /dense_3, /conv2d_1, etc.). Note that it must be the exact final operation of your model, and also that you must provide the tensor, which is the value .outputs[0] of the operation itself.
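If you don't know the names up front, a quick way to find them is to print every operation name under one model's scope; a minimal sketch (the 'model_0/' prefix matches the scopes created above):

graph = tf.get_default_graph()
for op in graph.get_operations():
    # list only the nodes imported under the first model's scope
    if op.name.startswith('model_0/'):
        print(op.name)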
Finally, you can run the session with the correct inference operation and placeholder, get the prediction as a NumPy array, and do whatever you want with it (averaging, majority voting, etc.).
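For example, a simple way to average the ensemble outside of TensorFlow (a minimal sketch; your_input stands in for your actual batch, as in the snippet above):

predictions = [sess.run(inference_ops[i], feed_dict={placeholders[i]: your_input})
               for i in xrange(num_of_ensembles)]
# element-wise mean over the N model outputs
ensemble_prediction = np.mean(predictions, axis=0)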
This requires a few hacks. Let us first save a few simple models:
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf


def build_graph(init_val=0.0):
    x = tf.placeholder(tf.float32)
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--init', help='initial value for w', type=float)
    parser.add_argument('--path', help='path of the checkpoint to write', type=str)
    args = parser.parse_args()

    x1, y1 = build_graph(args.init)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(y1, {x1: 10}))  # outputs: 10 + init
        save_path = saver.save(sess, args.path)
        print("Model saved in path: %s" % save_path)

# python ensemble.py --init 1 --path ./models/model1.chpt
# python ensemble.py --init 2 --path ./models/model2.chpt
# python ensemble.py --init 3 --path ./models/model3.chpt
These models produce outputs of "10 + i", where i = 1, 2, 3. Note that this script creates, runs, and saves the same graph structure multiple times. Loading these values and restoring each graph individually is standard practice and can be done as follows:
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf


def build_graph(init_val=0.0):
    x = tf.placeholder(tf.float32)
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', help='path of the checkpoint to load', type=str)
    args = parser.parse_args()

    x1, y1 = build_graph(-5.)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, args.path)
        print("Model loaded from path: %s" % args.path)
        print(sess.run(y1, {x1: 10}))

# python ensemble_load.py --path ./models/model1.chpt  # gives 11
# python ensemble_load.py --path ./models/model2.chpt  # gives 12
# python ensemble_load.py --path ./models/model3.chpt  # gives 13
These again produce the outputs 11, 12, 13, as expected. Now the trick is to create its own scope for each model of the ensemble, like so:
import numpy as np
import tensorflow as tf


def build_graph(x, init_val=0.0):
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']

    x = tf.placeholder(tf.float32)
    outputs = []
    for k, path in enumerate(models):
        # THE VARIABLE SCOPE IS IMPORTANT
        with tf.variable_scope('model_%03i' % (k + 1)):
            outputs.append(build_graph(x, -100 * np.random.rand())[1])
Hence each model lives under a different variable scope, i.e. we have the variables model_001/w:0, model_002/w:0, model_003/w:0. Although they belong to similar (not the same) sub-graphs, these variables are indeed different objects. Now the trick is to manage two sets of variables (those of the graph under the current scope and those from the checkpoint):
def restore_collection(path, scopename, sess):
    # retrieve all variables under the scope
    variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
    # iterate over all variables in the checkpoint
    for var_name, _ in tf.contrib.framework.list_variables(path):
        # get the value of the variable
        var_value = tf.contrib.framework.load_variable(path, var_name)
        # construct the expected variable name under the new scope
        target_var_name = '%s/%s:0' % (scopename, var_name)
        # reference to the variable-tensor
        target_variable = variables[target_var_name]
        # assign the old value from the checkpoint to the new variable
        sess.run(target_variable.assign(var_value))
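As an alternative, you could probably achieve the same remapping with tf.train.init_from_checkpoint, which rewrites the initializers of the scoped variables instead of running assign ops. A hedged sketch (the '/' key maps the checkpoint's root scope onto the new scope):

def restore_collection_v2(path, scopename):
    # must be called BEFORE tf.global_variables_initializer()
    tf.train.init_from_checkpoint(path, {'/': scopename + '/'})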
The full solution would be:
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf


def restore_collection(path, scopename, sess):
    # retrieve all variables under the scope
    variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
    # iterate over all variables in the checkpoint
    for var_name, _ in tf.contrib.framework.list_variables(path):
        # get the value of the variable
        var_value = tf.contrib.framework.load_variable(path, var_name)
        # construct the expected variable name under the new scope
        target_var_name = '%s/%s:0' % (scopename, var_name)
        # reference to the variable-tensor
        target_variable = variables[target_var_name]
        # assign the old value from the checkpoint to the new variable
        sess.run(target_variable.assign(var_value))


def build_graph(x, init_val=0.0):
    w = tf.get_variable('w', initializer=init_val)
    y = x + w
    return x, y


if __name__ == '__main__':
    models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']

    x = tf.placeholder(tf.float32)
    outputs = []
    W = []
    for k, path in enumerate(models):
        scope = 'model_%03i' % (k + 1)
        # THE VARIABLE SCOPE IS IMPORTANT
        with tf.variable_scope(scope):
            outputs.append(build_graph(x, -100 * np.random.rand())[1])
        # keep a handle to each model's weight variable
        W.append(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)[0])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        print(sess.run(outputs[0], {x: 10}))  # random output, e.g. -82.4929
        print(sess.run(outputs[1], {x: 10}))  # random output, e.g. -63.65792
        print(sess.run(outputs[2], {x: 10}))  # random output, e.g. -19.888203

        print(sess.run(W[0]))  # randomly initialized value, e.g. -92.4929
        print(sess.run(W[1]))  # randomly initialized value, e.g. -73.65792
        print(sess.run(W[2]))  # randomly initialized value, e.g. -29.888203

        # restore all variables from the different checkpoints
        restore_collection(models[0], 'model_001', sess)
        restore_collection(models[1], 'model_002', sess)
        restore_collection(models[2], 'model_003', sess)

        print(sess.run(W[0]))  # old value from checkpoint: 1.0
        print(sess.run(W[1]))  # old value from checkpoint: 2.0
        print(sess.run(W[2]))  # old value from checkpoint: 3.0

        print(sess.run(outputs[0], {x: 10}))  # what we expect: 11.0
        print(sess.run(outputs[1], {x: 10}))  # what we expect: 12.0
        print(sess.run(outputs[2], {x: 10}))  # what we expect: 13.0

# python ensemble_load_all.py
Now, having a list of outputs, you can average these values within TensorFlow or do some other kind of ensemble prediction.
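For instance, a mean prediction over the ensemble can be built directly in the graph; a minimal sketch, reusing outputs and x from the script above (run inside the same session):

# stack the N scalar outputs and average them in-graph
ensemble_mean = tf.reduce_mean(tf.stack(outputs), axis=0)
print(sess.run(ensemble_mean, {x: 10}))  # (11 + 12 + 13) / 3 = 12.0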