I want to update a variable in TensorFlow, and for that reason I use `tf.while_loop` like this:
a = tf.Variable([0, 0, 0, 0, 0, 0] , dtype = np.int16)
i = tf.consta
This isn't obvious until you actually write the code and run it. The pattern looks like this:
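```python
import tensorflow as tf  # TF 1.x graph-mode API (tf.Session, tf.get_variable)

def cond(size, i):
    return tf.less(i, size)

def body(size, i):
    a = tf.get_variable("a", [6], dtype=tf.int32,
                        initializer=tf.constant_initializer(0))
    a = tf.scatter_update(a, i, i)
    tf.get_variable_scope().reuse_variables()  # Reuse variables
    # Ops created inside this block (here, i + 1) run only after the
    # scatter_update, so the loop output depends on the update.
    with tf.control_dependencies([a]):
        return (size, i + 1)

with tf.Session() as sess:
    i = tf.constant(0)
    size = tf.constant(6)
    _, i = tf.while_loop(cond, body, [size, i])
    a = tf.get_variable("a", [6], dtype=tf.int32)  # reuses the loop's "a"
    init = tf.global_variables_initializer()  # replaces deprecated initialize_all_variables()
    sess.run(init)
    print(sess.run([a, i]))
```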
The output is:

```
[array([0, 1, 2, 3, 4, 5]), 6]
```
The `tf.control_dependencies([a])` block ensures that `scatter_update` happens before the while loop increments `i` and returns. The variable doesn't get updated without it.

Note: I didn't really understand the meaning or cause of the error either. I get that error too.
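Why does the update disappear without the control dependency? Roughly: a TF1 session only runs the ops that the fetched tensors depend on, through data edges or control edges, and inside `tf.while_loop` the body's only outputs are the returned loop variables. The `scatter_update` result is never returned, so nothing depends on it unless `i + 1` is given a control edge to it. The sketch below is a toy model of that pruning in plain Python, not TensorFlow's actual executor; `Op`, `run` and `one_iteration` are hypothetical names made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Op:
    """Hypothetical toy graph node: not a TensorFlow class."""
    name: str
    fn: Callable[..., object]
    inputs: List["Op"] = field(default_factory=list)          # data edges
    control_inputs: List["Op"] = field(default_factory=list)  # ordering-only edges


def run(fetch: Op, log: List[str]) -> object:
    """Evaluate `fetch`, running only ops reachable from it, each at most once."""
    cache = {}

    def visit(op: Op) -> object:
        if id(op) in cache:
            return cache[id(op)]
        for dep in op.control_inputs:
            visit(dep)  # run for its side effect only; the value is ignored
        args = [visit(x) for x in op.inputs]
        log.append(op.name)
        cache[id(op)] = op.fn(*args)
        return cache[id(op)]

    return visit(fetch)


def one_iteration(i_value: int, with_control_dep: bool):
    """One loop step: fetch only i + 1, like the loop output of the body."""
    store = [0] * 6  # stands in for the variable "a"
    log: List[str] = []

    def write(idx: int):
        store[idx] = idx  # side effect, like scatter_update(a, i, i)
        return store

    i = Op("i", lambda: i_value)
    scatter = Op("scatter_update", write, inputs=[i])
    next_i = Op("add", lambda x: x + 1, inputs=[i],
                control_inputs=[scatter] if with_control_dep else [])
    result = run(next_i, log)
    return result, store, log


print(one_iteration(3, with_control_dep=False))
print(one_iteration(3, with_control_dep=True))
```

In the first call the write never runs, because `add` reaches only `i`; in the second the control edge pulls `scatter_update` in, and it runs before `add`.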