Update a variable with tf.while_loop in Tensorflow

醉酒成梦 2021-01-21 12:14

I want to update a variable in TensorFlow, so I use tf.while_loop like this:

a = tf.Variable([0, 0, 0, 0, 0, 0] , dtype = np.int16)

i = tf.consta         


        
1 answer
  •  臣服心动
    2021-01-21 12:45

    This isn't obvious until you write and run the code. The pattern looks like this:

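    import tensorflow as tf  # written for TensorFlow 1.x graph mode
    
    
    def cond(size, i):
        # Keep looping while i < size.
        return tf.less(i, size)
    
    def body(size, i):
        # Create the variable "a" (six int32 zeros).
        a = tf.get_variable("a", [6], dtype=tf.int32, initializer=tf.constant_initializer(0))
        a = tf.scatter_update(a, i, i)  # write the value i at index i
    
        # Mark the scope as reusing, so the later get_variable("a", ...) returns this variable.
        tf.get_variable_scope().reuse_variables()
        # Make the increment depend on the update, so the update actually runs.
        with tf.control_dependencies([a]):
            return (size, i + 1)
    
    with tf.Session() as sess:
    
        i = tf.constant(0)
        size = tf.constant(6)
        _, i = tf.while_loop(cond, body, [size, i])
    
        # Fetch the variable that was created inside the loop body.
        a = tf.get_variable("a", [6], dtype=tf.int32)
    
        # tf.initialize_all_variables() is deprecated in favor of this call.
        init = tf.global_variables_initializer()
        sess.run(init)
    
        print(sess.run([a, i]))
    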

    Output is

    [array([0, 1, 2, 3, 4, 5]), 6]

    1. tf.get_variable gets an existing variable with these parameters or creates a new one.
    2. tf.control_dependencies establishes a happens-before relationship. As I understand it, here it makes scatter_update run before the loop increments i and returns. Without it, the variable is not updated.

    Note: I don't really understand the meaning or cause of the error either. I get it too.
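
For readers on TensorFlow 2.x, where tf.Session, tf.get_variable, and tf.scatter_update are no longer top-level APIs, the same result can be sketched by carrying the array through tf.while_loop as a loop variable instead of mutating a tf.Variable. This is not the pattern from the answer above, only a minimal sketch assuming TensorFlow 2.x with eager execution enabled. The names i0, a0, i_final, and a_final, and the choice of tf.tensor_scatter_nd_update, are illustrative assumptions.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution on by default)


def cond(i, a):
    # Keep looping while i is smaller than the length of a.
    return tf.less(i, tf.shape(a)[0])


def body(i, a):
    # Return a copy of a with the value i written at index i.
    # indices must have shape [num_updates, 1]; updates has shape [num_updates].
    a = tf.tensor_scatter_nd_update(a, tf.reshape(i, [1, 1]), tf.reshape(i, [1]))
    return i + 1, a


i0 = tf.constant(0)
a0 = tf.zeros([6], dtype=tf.int32)

# Both the counter and the array are loop variables, so each iteration
# receives the array produced by the previous one.
i_final, a_final = tf.while_loop(cond, body, loop_vars=(i0, a0))

print(a_final.numpy().tolist(), int(i_final))  # [0, 1, 2, 3, 4, 5] 6
```

Because the array is passed from one iteration to the next as data, the ordering between the update and the increment follows from the data flow, so no tf.control_dependencies block is needed in this variant.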
