Why does tf.assign() slow the execution time?

生来不讨喜 2021-01-20 16:23

Today I added learning rate decay to my LSTM in TensorFlow.

I changed

train_op = tf.train.RMSPropOptimizer(lr_rate).minimize(loss)

t

3 Answers
  •  无人共我
    2021-01-20 17:09

    The problem you have has nothing to do with sess.run or tf.assign themselves. It is a very common issue in many models: your model is slow because your graph keeps growing (a "bloated" graph). I will explain what this means with a very simple example that has nothing to do with your code. Take a look at these 2 snippets:

    Snippet 1

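    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            # tf.add() is called inside the loop, so a new Add op is added to the graph on every iteration
            print(sess.run(tf.add(a, b)), end=' ')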
    

    Snippet 2

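    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    res = tf.add(a, b)  # the Add op is created once, before the loop
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            print(sess.run(res), end=' ')  # reuses the same op on every iteration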
    

    Both snippets print the same values and look like they do the same thing. The difference is that they build different graphs. If you print len(tf.get_default_graph().get_operations()) after the loop, you will see that Snippet 1 has more operations than Snippet 2, because every call to tf.add() inside the loop adds a new node, so the graph keeps growing and each new op has to be processed by the session before it can run. Increase the range to a few thousand and the difference becomes significant.
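    The growth can be measured directly. The sketch below is a minimal illustration, assuming TensorFlow 1.15 or 2.x with the tf.compat.v1 graph-mode API; count_ops is a hypothetical helper that runs either snippet's pattern in a fresh graph and returns the final number of ops. Exact counts depend on the TensorFlow version (for example, whether variables are resource variables), so only the trend matters.

```python
import tensorflow.compat.v1 as tf  # TF1-style graph API, also available in TF 2.x

tf.disable_eager_execution()  # tf.Session requires graph mode


def count_ops(build_inside_loop, steps):
    """Hypothetical helper: run `steps` additions, return the graph's final op count."""
    graph = tf.Graph()
    with graph.as_default():
        a = tf.Variable(1, name='a')
        b = tf.Variable(2, name='b')
        # Snippet 2 pattern: build the Add op once, before running anything
        res = None if build_inside_loop else tf.add(a, b)
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(steps):
                if build_inside_loop:
                    # Snippet 1 pattern: every call creates new ops in the graph
                    sess.run(tf.add(a, b))
                else:
                    sess.run(res)
    return len(graph.get_operations())


for steps in (3, 500):
    print(steps, count_ops(True, steps), count_ops(False, steps))
```

    With the op built inside the loop, the count grows by at least one op per step; with the op built once, it stays the same no matter how many steps run.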

    Your code has the same bloated-graph problem: in each iteration of the loop, tf.assign(lr, lr_rate*0.9**epoch) creates 3 new nodes in the graph. Define the graph once, separately from the loop that runs it, and you will see the improvement.
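    One way to apply this to learning-rate decay is to create the update op once and feed only the new value each epoch. The sketch below assumes the same tf.compat.v1 API; lr_rate, num_epochs, new_lr and lr_update are hypothetical names and values, and the model and loss are omitted. Calling graph.finalize() makes the graph read-only, so any op accidentally created inside the loop raises a RuntimeError instead of silently slowing training down.

```python
import tensorflow.compat.v1 as tf  # TF1-style graph API, also available in TF 2.x

tf.disable_eager_execution()  # tf.Session requires graph mode

lr_rate = 0.01   # hypothetical initial learning rate
num_epochs = 5   # hypothetical number of epochs

# Graph definition: every op is created exactly once, before training starts
lr = tf.Variable(lr_rate, trainable=False, name='lr')
new_lr = tf.placeholder(tf.float32, shape=[], name='new_lr')
lr_update = tf.assign(lr, new_lr)  # a single assign op, reused every epoch
# (the model, loss and tf.train.RMSPropOptimizer(lr).minimize(loss) would go here)

graph = tf.get_default_graph()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph.finalize()  # from now on, adding any op raises RuntimeError
    ops_before = len(graph.get_operations())
    lrs = []
    for epoch in range(num_epochs):
        # Only the value changes; it is fed in, so no new nodes are created
        sess.run(lr_update, feed_dict={new_lr: lr_rate * 0.9 ** epoch})
        lrs.append(sess.run(lr))
    ops_after = len(graph.get_operations())

print(ops_before == ops_after, lrs)
```

    As an alternative, tf.train.exponential_decay(lr_rate, global_step, decay_steps, 0.9, staircase=True) computes the decayed learning rate inside the graph itself (with global_step passed to minimize()), so no assign op is needed at all.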
