I am trying to implement a crude method based on the Mixture-of-Experts paper (https://arxiv.org/abs/1701.06538) in TensorFlow.
There wo
This seems to be doable with `tf.cond`:
import tensorflow as tf
def make_conditional_train_op(
        should_update, optimizers, variable_lists, losses):
    """Build one train op that updates each model only when it is selected.

    The four arguments are parallel sequences of equal length: entry ``i``
    of each one describes model ``i``.

    Args:
        should_update: Sequence of boolean Tensors, each used as the
            ``tf.cond`` predicate for the corresponding model.
        optimizers: Sequence of optimizers, one per model.
        variable_lists: Sequence of variable lists; ``variable_lists[i]`` is
            passed as ``var_list`` to ``optimizers[i].minimize``.
        losses: Sequence of loss Tensors, one per model.

    Returns:
        A single op (``tf.group``) wrapping all the conditional updates.

    Raises:
        ValueError: If the four sequences do not all have the same length.
    """
    # Validate explicitly instead of with `assert`: assertions are removed
    # when Python runs with -O, and they carry no diagnostic message.
    argument_lengths = (
        ("should_update", len(should_update)),
        ("optimizers", len(optimizers)),
        ("variable_lists", len(variable_lists)),
        ("losses", len(losses)),
    )
    if len({length for _, length in argument_lengths}) > 1:
        raise ValueError(
            "All arguments must have the same length, got "
            + ", ".join(
                "{}={}".format(name, length)
                for name, length in argument_lengths))

    conditional_updates = []
    per_model = zip(should_update, optimizers, variable_lists, losses)
    for model_number, model in enumerate(per_model):
        update_boolean, optimizer, variables, loss = model
        # The minimize() call must be made *inside* the branch callable so
        # that the variable (and optimizer accumulator) assignments belong to
        # the true branch and only run when `update_boolean` is True.
        # NOTE: the lambda closes over this iteration's loop variables; this
        # relies on tf.cond invoking the branch callables immediately, while
        # the op is being built, rather than after the loop has moved on.
        conditional_updates.append(
            tf.cond(
                update_boolean,
                lambda: tf.group(
                    optimizer.minimize(loss, var_list=variables),
                    tf.Print(
                        0,
                        ["Model {} updating".format(model_number), loss])),
                tf.no_op))
    return tf.group(*conditional_updates)
The basic strategy is to make sure that the optimizer's variable updates are defined inside the lambda passed as one of the `tf.cond` branches. That gives true conditional op execution: the assignments to the variables (and to the optimizer's accumulators) only happen if that branch of the `cond` is triggered.
As an example, we can construct some models:
def make_model_and_optimizer():
    """Create one toy model: two variables, a loss and an Adam optimizer.

    Returns:
        A tuple ``(optimizer, variables, loss)`` where ``variables`` is the
        list ``[scalar, vector]`` of the two variables created here and
        ``loss`` is the sum of the elements of ``scalar * vector``.
    """
    scalar = tf.get_variable("scalar", shape=[])
    vector = tf.get_variable("vector", shape=[3])
    trainable = [scalar, vector]
    # Summing the elementwise product gives a scalar loss that depends on
    # both variables.
    loss = tf.reduce_sum(scalar * vector)
    return tf.train.AdamOptimizer(0.1), trainable, loss
# Build ten independent copies of the toy model, collecting their parts in
# three parallel lists (index i of each list belongs to model i).
optimizers, variable_lists, losses = [], [], []
for model_index in range(10):
    # One variable scope per model: "model_0" ... "model_9".
    with tf.variable_scope("model_{}".format(model_index)):
        model_optimizer, model_variables, model_loss = (
            make_model_and_optimizer())
    optimizers.append(model_optimizer)
    variable_lists.append(model_variables)
    losses.append(model_loss)
Then determine a conditional update strategy, in this case only training the model with the maximum loss (just because that results in more switching; the output is rather boring if only one model ever updates):
# Pick the model to update: the one whose loss is currently the largest.
stacked_losses = tf.stack(losses)
max_loss_index = tf.argmax(stacked_losses, axis=0)
# One-hot encode the winning index, then compare against ones to turn the
# encoding into a boolean mask that is True only at that index.
integer_one_hot = tf.one_hot(max_loss_index, depth=len(losses))
is_max = tf.equal(integer_one_hot, tf.ones_like(integer_one_hot))
Finally, we can call the `make_conditional_train_op` function to create the train op, then run some training iterations:
# One boolean per model: unstacking the mask yields the per-model predicates.
train_op = make_conditional_train_op(
    should_update=tf.unstack(is_max),
    optimizers=optimizers,
    variable_lists=variable_lists,
    losses=losses)

# Initialise all variables, then run the conditional train op for 20 steps.
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(20):
        print("Iteration {}".format(step))
        session.run(train_op)
This prints the index of the model being updated and its loss at each iteration, confirming the conditional execution:
Iteration 0
I tensorflow/core/kernels/logging_ops.cc:79] [Model 6 updating][2.7271919]
Iteration 1
I tensorflow/core/kernels/logging_ops.cc:79] [Model 6 updating][2.1755948]
Iteration 2
I tensorflow/core/kernels/logging_ops.cc:79] [Model 2 updating][1.9858969]
Iteration 3
I tensorflow/core/kernels/logging_ops.cc:79] [Model 6 updating][1.6859927]