tf.assign to variable slice doesn't work inside tf.while_loop

我在风中等你 2021-01-13 04:32

What is wrong with the following code? The tf.assign op works just fine when applied to a slice of a tf.Variable if it happens outside of a loop.

3 answers
  •  小蘑菇 (OP)
    2021-01-13 05:15

    Your variable is not an output of the operations run inside the loop; it is an external entity that lives outside the loop. So you do not have to pass it to tf.while_loop as a loop variable.

    You also need to force the update to actually run, for example by using tf.control_dependencies in body. Otherwise nothing the loop returns depends on the assign op, so it is never executed.

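    # TensorFlow 1.x (graph mode) API
    import tensorflow as tf
    
    v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    n = len(v)
    a = tf.Variable(v, name='a')  # lives outside the loop; not a loop variable
    
    def cond(i):
        return i < n
    
    def body(i):
        # Write a[i-1] + a[i-2] into the slice a[i]
        op = tf.assign(a[i], a[i-1] + a[i-2])
        # Make the returned counter depend on the assignment,
        # so the assign op actually runs in every iteration
        with tf.control_dependencies([op]):
            return i + 1
    
    # Only the counter i is carried through the loop
    i = tf.while_loop(cond, body, [2])
    
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    i.eval()  # running the loop performs the assignments
    print(a.eval())
    # [ 1  1  2  3  5  8 13 21 34 55 89]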
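
Why the control dependency in body matters can be illustrated with a toy model. The sketch below is plain Python, not TensorFlow's actual executor; the Node class and the run, step and fill helpers are hypothetical names made up for this illustration. It assumes the graph-mode rule the answer relies on: evaluating a node runs only the nodes it depends on, through data inputs or control inputs, so an assign op that nothing depends on is skipped.

```python
# Toy model (hypothetical, NOT TensorFlow's executor) of graph-mode evaluation:
# fetching a node runs only the nodes it depends on via data or control inputs.


class Node:
    """A hypothetical graph node: a function plus its dependencies."""

    def __init__(self, fn, inputs=(), control_inputs=()):
        self.fn = fn                                # computes the value from input values
        self.inputs = list(inputs)                  # data dependencies
        self.control_inputs = list(control_inputs)  # must run first; value is ignored


def run(node, cache):
    """Evaluate `node`, executing each reachable node at most once."""
    if node in cache:
        return cache[node]
    for dep in node.control_inputs:
        run(dep, cache)  # side effects of control inputs happen here
    values = [run(x, cache) for x in node.inputs]
    cache[node] = node.fn(*values)
    return cache[node]


def step(state, with_control_dep):
    """One loop-body iteration: build a tiny graph, then fetch only next_i."""
    a = state["a"]
    i = Node(lambda: state["i"])

    def assign(k):
        a[k] = a[k - 1] + a[k - 2]  # the side effect, like tf.assign(a[i], ...)

    assign_op = Node(assign, inputs=[i])
    deps = [assign_op] if with_control_dep else []
    next_i = Node(lambda k: k + 1, inputs=[i], control_inputs=deps)
    state["i"] = run(next_i, cache={})


def fill(with_control_dep):
    state = {"a": [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], "i": 2}
    while state["i"] < len(state["a"]):  # plays the role of cond(i)
        step(state, with_control_dep)
    return state["a"]


print(fill(with_control_dep=False))  # [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(fill(with_control_dep=True))   # [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
```

In the toy, fetching only the counter never reaches the assignment unless it is a control input, so a keeps its initial values; with the control dependency, each iteration's assignment runs before the counter advances.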
    

    To be safe, you may also want to set parallel_iterations=1 (i.e. tf.while_loop(cond, body, [2], parallel_iterations=1)) to force the loop iterations to run sequentially.
