Why does tf.assign() slow down execution?

生来不讨喜 2021-01-20 16:23

Today I added learning rate decay to my LSTM in TensorFlow.

I changed

train_op = tf.train.RMSPropOptimizer(lr_rate).minimize(loss)

to …

3 Answers
  • 2021-01-20 17:09

    The problem you have has nothing to do with sess.run or tf.assign. This is a very common issue in many models: your model is slow because your graph is bloated. I will explain what this means with a very simple example that has nothing to do with your code. Take a look at these 2 snippets:

    Snippet 1

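    import tensorflow as tf  # TensorFlow 1.x API

    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            # tf.add() is called inside the loop, so every iteration adds a new node to the graph
            print(sess.run(tf.add(a, b)), end=' ')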
    

    Snippet 2

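    import tensorflow as tf  # TensorFlow 1.x API

    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    res = tf.add(a, b)  # the add node is created once, before the session runs
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            print(sess.run(res), end=' ')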
    

    Both snippets print the same values and appear to do the same thing. The problem is that they build different graphs: if you print len(tf.get_default_graph().get_operations()) after the loop, you will see that Snippet 1 has more nodes than Snippet 2, because every tf.add() call inside its loop creates another node. Increase the range to a few thousand and the difference becomes significant.
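    To see the growth pattern without TensorFlow installed, here is a hypothetical pure-Python model of the two snippets. ToyGraph and both helper functions are made up for illustration and are not TensorFlow code; real TF variables create several ops each, so the actual counts from get_operations() are larger, but the trend is the same:

```python
# Hypothetical pure-Python model of a TF1 graph: it only records the nodes created.
# It is NOT TensorFlow; it just counts nodes the way the two snippets create them.


class ToyGraph:
    def __init__(self):
        self.ops = []

    def add_op(self, name):
        """Append a node to the graph and return its name (like creating a TF op)."""
        self.ops.append(name)
        return name


def snippet1_op_count(steps):
    """Snippet 1 pattern: the add node is created inside the loop."""
    g = ToyGraph()
    a, b = g.add_op("a"), g.add_op("b")
    for _ in range(steps):
        g.add_op(f"add({a}, {b})")  # a fresh node on every iteration
    return len(g.ops)


def snippet2_op_count(steps):
    """Snippet 2 pattern: the add node is created once, then only executed."""
    g = ToyGraph()
    a, b = g.add_op("a"), g.add_op("b")
    res = g.add_op(f"add({a}, {b})")  # built once, before the loop
    for _ in range(steps):
        _ = res  # "running" an existing node does not grow the graph
    return len(g.ops)


for steps in (3, 3000):
    print(steps, snippet1_op_count(steps), snippet2_op_count(steps))
```

    The first count grows linearly with the number of iterations while the second stays constant.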

    Your code has the same bloated-graph problem: in each iteration of the loop you call tf.assign(lr, lr_rate*0.9**epoch), which adds 3 new nodes to the graph. Separate graph construction from graph execution (build every op once, before the loop) and you will see the improvement.
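    One way to enforce that separation in TF1 is tf.Graph.finalize(), which makes a graph read-only so that any op accidentally created inside the training loop raises a RuntimeError instead of silently bloating the graph. The sketch below only mimics that guard in plain Python; FinalizableGraph, train() and the op names are hypothetical, not TensorFlow code:

```python
class FinalizableGraph:
    """Hypothetical stand-in for a TF1 graph that supports a finalize() guard."""

    def __init__(self):
        self.ops = []
        self.finalized = False

    def add_op(self, name):
        # Like a finalized tf.Graph: refuse to create new nodes after finalize().
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.ops.append(name)
        return name

    def finalize(self):
        self.finalized = True


def train(epochs, build_in_loop):
    """Return the final node count; raises RuntimeError if the loop creates nodes."""
    g = FinalizableGraph()
    lr = g.add_op("lr")
    update_lr = g.add_op(f"assign({lr}, new_value)")  # built once, before training
    g.finalize()  # from here on, the loop may only run existing nodes
    for epoch in range(epochs):
        if build_in_loop:
            # The question's pattern: a new assign node per epoch, so this raises.
            g.add_op(f"assign({lr}, lr_rate*0.9**{epoch})")
        else:
            _ = update_lr  # "run" the prebuilt node, feeding the new value at run time
    return len(g.ops)


print(train(5, build_in_loop=False))
try:
    train(5, build_in_loop=True)
except RuntimeError as err:
    print("caught:", err)
```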

  • 2021-01-20 17:16

    I'm not sure why this would slow the process down so much (according to https://github.com/tensorflow/tensorflow/issues/1439, constantly creating new graph nodes seems to cause it), but I think it is better to feed the learning rate through feed_dict:

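    import tensorflow as tf  # TensorFlow 1.x API

    learn_rate = tf.placeholder(tf.float32, shape=[])  # scalar learning rate, fed on every run
    optimizer = tf.train.AdamOptimizer(learn_rate)
    minimizer = optimizer.minimize(loss)  # `loss` is your model's loss tensor
    ...
    learnrate = 1e-5
    ...
    sess.run(minimizer, feed_dict={learn_rate: learnrate})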
    

    I use this approach and see no performance issue. Moreover, because you can feed an arbitrary value on every step, you can even increase or decrease the learning rate based on the training or validation error.
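    Because the value is fed on every sess.run() call, the schedule itself can be plain Python. As one hypothetical example (not part of the answer), a simple "reduce on plateau" rule halves the rate when the validation loss stops improving; reduce_on_plateau() and the loss values below are made up for illustration, and the returned number is what you would pass as feed_dict={learn_rate: ...}:

```python
def reduce_on_plateau(lr, val_losses, factor=0.5, patience=2, min_lr=1e-7):
    """Hypothetical schedule: multiply lr by `factor` when none of the last
    `patience` validation losses beat the best loss recorded before them."""
    if len(val_losses) <= patience:
        return lr  # not enough history yet
    best_before = min(val_losses[:-patience])
    if min(val_losses[-patience:]) >= best_before:
        return max(lr * factor, min_lr)
    return lr


# Illustrative, made-up validation losses, one per epoch.
lr = 1e-3
history = []
for val_loss in [0.90, 0.70, 0.72, 0.71, 0.65]:
    history.append(val_loss)
    lr = reduce_on_plateau(lr, history)
    # In the TF1 loop this value would be fed: sess.run(minimizer, feed_dict={learn_rate: lr})
    print(f"val_loss={val_loss:.2f} lr={lr:g}")
```

    A fixed per-epoch decay such as lr_rate*0.9**epoch works the same way: compute the number in Python and feed it.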

  • 2021-01-20 17:17

    A threefold increase in computation time seems a bit odd, but here are some things you can try:

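    • Create the op that updates your learning rate once, as part of the graph. Your code creates a new operation at every step, and each one is added to the graph, which can cost extra time. In general, it is best practice to create all the necessary operations before opening the tf.Session(). Because the op is built only once, feed the current epoch through a placeholder so the assigned value can still change:
    epoch_ph = tf.placeholder(tf.float32, shape=[])  # current epoch, fed at run time
    update_lr = tf.assign(lr, lr_rate * 0.9 ** epoch_ph)
    
    • Use only one sess.run() per iteration, combining the training op and update_lr:
    sess.run([train_op, update_lr], feed_dict={epoch_ph: epoch})  # add your input placeholders to feed_dict too
    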
    • A more efficient way to implement a decaying learning rate is tf.train.exponential_decay(). If you want to decay by 0.9 every epoch, you can do:
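    import tensorflow as tf  # TensorFlow 1.x API

    training_size = 60000  # number of examples in an epoch
    batch_size = 100       # example value; use your own batch size
    steps_per_epoch = training_size // batch_size  # global_step counts batches, not examples
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = 0.1
    learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                               steps_per_epoch, 0.9, staircase=True)
    # Passing global_step to minimize() will increment it at each step.
    train_op = tf.train.RMSPropOptimizer(learning_rate).minimize(loss, global_step=global_step)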
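    The TF1 documentation gives the result of tf.train.exponential_decay() as learning_rate * decay_rate ^ (global_step / decay_steps), where the division is an integer division when staircase=True. A plain-Python version of that formula (decayed_lr is a hypothetical name, and 600 steps per epoch is only an example value) makes it easy to check which rate a given step receives. With staircase=True it reproduces the question's per-epoch rule lr_rate*0.9**epoch when decay_steps equals the number of steps per epoch:

```python
import math


def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    """Plain-Python version of the formula documented for tf.train.exponential_decay."""
    exponent = global_step / decay_steps
    if staircase:
        exponent = math.floor(exponent)  # decay in discrete jumps, once per decay_steps
    return learning_rate * decay_rate ** exponent


steps_per_epoch = 600  # e.g. 60000 examples / batch size 100 (example values)
for step in (0, 599, 600, 1200, 1800):
    epoch = step // steps_per_epoch
    lr = decayed_lr(0.1, step, steps_per_epoch, 0.9, staircase=True)
    # With staircase=True this matches the per-epoch rule lr_rate * 0.9 ** epoch.
    assert math.isclose(lr, 0.1 * 0.9 ** epoch)
    print(step, epoch, round(lr, 6))
```

    Without staircase=True the rate shrinks smoothly on every step instead of once per epoch.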
    