I am trying to implement an adversarial NN, which requires 'freezing' one or the other part of the graph during alternating training minibatches. I.e., there are two sub-networks
Another option you might want to consider is setting trainable=False on a variable, which means it will not be modified by training:
tf.Variable(my_weights, trainable=False)
The easiest way to achieve this, as you mention in your question, is to create two optimizer operations using separate calls to opt.minimize(cost, ...). By default, the optimizer will use all of the variables in tf.trainable_variables(). If you want to filter the variables to a particular scope, you can use the optional scope argument to tf.get_collection() as follows:
optimizer = tf.train.AdagradOptimizer(0.01)
# "first_scope" and "second_scope" are placeholder scope names; substitute your own.
first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope="first_scope")
first_train_op = optimizer.minimize(cost, var_list=first_train_vars)
second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      scope="second_scope")
second_train_op = optimizer.minimize(cost, var_list=second_train_vars)
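To illustrate what the scope filter does: as far as I know, tf.get_collection() keeps the items whose name matches scope via re.match, so the scope behaves like a prefix pattern. The sketch below mimics that with plain strings and no TensorFlow at all. The variable names and the filter_by_scope helper are hypothetical, made up for this illustration.

```python
import re

# Hypothetical variable names, as they might look for two sub-networks.
variable_names = [
    "generator/weights:0",
    "generator/bias:0",
    "discriminator/weights:0",
    "discriminator/bias:0",
]


def filter_by_scope(names, scope):
    """Hypothetical helper: keep names that match `scope` from the start (re.match)."""
    pattern = re.compile(scope)
    return [name for name in names if pattern.match(name)]


# Each list would then be handed to its own minimize(..., var_list=...) call.
first_train_vars = filter_by_scope(variable_names, "generator/")
second_train_vars = filter_by_scope(variable_names, "discriminator/")
```

Including the trailing slash in the scope should avoid accidentally picking up a sibling scope whose name merely starts with the same characters.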
I don't know if my approach has downsides, but I solved this issue for myself with this construct:
do_gradient = <Tensor that evaluates to 0 or 1>
no_gradient = 1 - do_gradient
wrapped_op = do_gradient * original + no_gradient * tf.stop_gradient(original)
So if do_gradient = 1, the values and gradients will flow through just fine, but if do_gradient = 0, then the values will only flow through the stop_gradient op, which will stop the gradients from flowing back.
For my scenario, hooking do_gradient up to an index of a random_shuffle tensor let me randomly train different pieces of my network.
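To see why the gate leaves the value alone and only switches the gradient, here is a small framework-free sketch using forward-mode dual numbers. It is only an illustration of the arithmetic, not TensorFlow code: the Dual class, the stop_gradient stand-in, and the gated helper are all assumptions made up for this example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value paired with its derivative w.r.t. a single input."""
    value: float
    grad: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.value + other.value, self.grad + other.grad)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.grad * other.value + self.value * other.grad)

    __rmul__ = __mul__


def _lift(x):
    """Treat plain numbers as constants with zero derivative."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def stop_gradient(x):
    """Hypothetical stand-in for tf.stop_gradient: same value, derivative dropped."""
    return Dual(x.value, 0.0)


def gated(original, do_gradient):
    """The construct from the answer above, with do_gradient equal to 0 or 1."""
    no_gradient = 1 - do_gradient
    return do_gradient * original + no_gradient * stop_gradient(original)


x = Dual(3.0, 1.0)    # the input, with dx/dx = 1
original = x * x      # value 9, derivative 2 * x = 6

open_gate = gated(original, 1)     # value and derivative both pass through
closed_gate = gated(original, 0)   # value passes through, derivative is zeroed
```

In both cases the forward value is unchanged, so the rest of the network sees the same activations whether or not that piece is currently being trained.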
@mrry's answer is completely right and perhaps more general than what I'm about to suggest. But I think a simpler way to accomplish it is to just pass the Python reference directly to var_list:
W = tf.Variable(...)
C = tf.Variable(...)
Y_est = tf.matmul(W,C)
loss = tf.reduce_sum((data-Y_est)**2)
optimizer = tf.train.AdamOptimizer(0.001)
# You can pass the python object directly
train_W = optimizer.minimize(loss, var_list=[W])
train_C = optimizer.minimize(loss, var_list=[C])
I have a self-contained example here: https://gist.github.com/ahwillia/8cedc710352eb919b684d8848bc2df3a
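For a feel of what alternating between train_W and train_C does, here is a minimal framework-free sketch with scalars instead of matrices and hand-written gradients instead of an optimizer. The data value, starting point, learning rate, step count, and the train_step helper are all assumptions chosen for illustration.

```python
def loss(w, c, data):
    """Squared error of the scalar 'factorization' data ~ w * c."""
    return (data - w * c) ** 2


def train_step(w, c, data, update, learning_rate=0.05):
    """One plain gradient step on either w or c; the other one stays frozen."""
    residual = data - w * c
    if update == "w":
        # d/dw (data - w*c)^2 = -2 * c * (data - w*c)
        w = w - learning_rate * (-2.0 * c * residual)
    elif update == "c":
        # d/dc (data - w*c)^2 = -2 * w * (data - w*c)
        c = c - learning_rate * (-2.0 * w * residual)
    else:
        raise ValueError("update must be 'w' or 'c'")
    return w, c


data = 6.0          # assumed target value
w, c = 1.0, 1.0     # assumed starting point

for step in range(200):
    # Even steps play the role of train_W, odd steps the role of train_C.
    w, c = train_step(w, c, data, "w" if step % 2 == 0 else "c")

final_loss = loss(w, c, data)
```

With real TensorFlow ops, the loop body would presumably run train_W and train_C in turn inside a session instead of calling train_step.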