Today I added learning rate decay to my LSTM in TensorFlow.
I changed
train_op = tf.train.RMSPropOptimizer(lr_rate).minimize(loss)
The problem you have has nothing to do with sess.run or tf.assign. It is a very common issue in many models: your model is slow because of a bloated graph. I will explain what all of this means with a very simple example that has nothing to do with your code. Take a look at these 2 snippets:
Snippet 1
import tensorflow as tf

a = tf.Variable(1, name='a')
b = tf.Variable(2, name='b')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        print(sess.run(tf.add(a, b)))
Snippet 2
import tensorflow as tf

a = tf.Variable(1, name='a')
b = tf.Variable(2, name='b')
res = tf.add(a, b)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        print(sess.run(res))
Both of them print the same values and look like they do the same thing. The problem is that they build different graphs: in Snippet 1, tf.add(a, b) is called inside the loop, so every iteration adds a fresh node to the graph, while Snippet 2 defines the op once. If you print len(tf.get_default_graph().get_operations()) after the loop, you will see that Snippet 1's graph has more nodes than Snippet 2's. Increase the range to a few thousand and the difference becomes significant.
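To see this concretely, you can count the ops yourself. Here is a minimal sketch of both snippets side by side (written with tf.compat.v1 so it also runs under TensorFlow 2; in TF 1.x the plain tf.* names work directly):

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Snippet 1 style: tf.add is called inside the loop, so every
# iteration adds another Add node to the graph.
g1 = tf1.Graph()
with g1.as_default():
    a = tf1.Variable(1, name='a')
    b = tf1.Variable(2, name='b')
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        for _ in range(3):
            sess.run(tf1.add(a, b))
n1 = len(g1.get_operations())

# Snippet 2 style: the op is defined once, before the loop.
g2 = tf1.Graph()
with g2.as_default():
    a = tf1.Variable(1, name='a')
    b = tf1.Variable(2, name='b')
    res = tf1.add(a, b)
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        for _ in range(3):
            sess.run(res)
n2 = len(g2.get_operations())

print(n1 > n2)  # Snippet 1's graph ends up with more ops
```

Every extra iteration of the Snippet 1 loop grows the graph by another node, and a growing graph is exactly what slows each subsequent sess.run down.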
You have the same problem: a bloated graph. In each iteration of your loop, tf.assign(lr, lr_rate*0.9**epoch) creates new nodes in the graph. Move the graph definition out of the graph run: define the assign op once, run only that op inside the loop, and you will see the improvement.