Question
I am trying to create a filter which depends on the current global_step
of the training but I am failing to do so properly.
First, I cannot use tf.train.get_or_create_global_step()
in the code below because it will throw
ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
This is why I tried fetching the scope with tf.get_default_graph().get_name_scope()
and within that context I was able to "get" the global step:
def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()
    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)
    return tf.size(example['targets']) <= max_subtokens

dataset = dataset.filter(filter_examples)
The problem is that this does not work as I expected. From what I am observing, current_step in the code above seems to be 0 all the time (I cannot verify this directly; I am inferring it from the filter's behavior). The only thing that seems to make a difference, weird as it sounds, is restarting the training: in that case, again judging from observations, current_step picks up the actual step count at the point of the restart, but the value then stays frozen and does not update as training continues.
Is there a way to get the actual value of the current step and use it in my filter like above?
Environment
TensorFlow 1.12.1
Answer 1:
As we discussed in the comments, keeping and updating your own counter might be an alternative to using the global_step variable. The counter variable could be updated as follows:

op = tf.assign_add(counter, 1)
with tf.control_dependencies([op]):
    # Some operation here before which the counter should be updated,
    # e.g. the training op.
    ...

Note that tf.control_dependencies expects a list of ops. Using it lets you "attach" the update of counter to a path within the computational graph; you can then use the counter variable wherever you need it.
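Below is a minimal, self-contained sketch of this pattern. The curriculum parameter names mirror the question but their values are made up, and the toy dataset and loss are stand-ins for the real ones:

import tensorflow as tf

# Hypothetical curriculum parameters, mirroring the question.
min_subtokens = 2
curriculum_step_size = 1
curriculum_step_update = 100.0

# A plain, non-trainable counter owned by the training loop.
counter = tf.get_variable("filter_counter", shape=[], dtype=tf.float32,
                          initializer=tf.zeros_initializer(),
                          trainable=False)

def filter_examples(example):
    subtokens_by_step = tf.floor(counter / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(
        subtokens_by_step, tf.int32)
    return tf.size(example['targets']) <= max_subtokens

# Toy dataset standing in for the real one; each element carries a
# variable-length 'targets' sequence.
dataset = tf.data.Dataset.from_generator(
    lambda: ({'targets': list(range(n))} for n in range(1, 30)),
    output_types={'targets': tf.int32})
dataset = dataset.filter(filter_examples)

# Attach the counter update to the training op so it advances exactly
# once per optimizer step. The weight and loss are placeholders.
w = tf.get_variable("w", shape=[], initializer=tf.ones_initializer())
loss = tf.square(w)
update_counter = tf.assign_add(counter, 1.0)
with tf.control_dependencies([update_counter]):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

Whether the filter actually sees the updated counter value while iterating depends on how the iterator is created; see Answer 2 below.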
Answer 2:
If you use variables inside datasets in TF 1.x, you need to reinitialize the iterator for updated variable values to be picked up:
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
init = iterator.initializer
tensors = iterator.get_next()

with tf.compat.v1.Session() as sess:
    for epoch in range(num_epochs):
        # Re-running the initializer lets the pipeline pick up the
        # current variable values.
        sess.run(init)
        for example in range(num_examples):
            tensor_vals = sess.run(tensors)
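To make the effect concrete, here is a self-contained toy sketch (all names and values hypothetical) that combines this with the counter idea from Answer 1: a variable controls the filter threshold, and the iterator is re-initialized each epoch so the pipeline can pick up the new value:

import tensorflow as tf

# Hypothetical counter controlling the filter threshold.
counter = tf.get_variable("epoch_counter", shape=[], dtype=tf.int64,
                          initializer=tf.zeros_initializer(),
                          trainable=False)
increment_counter = tf.assign_add(counter, 1)

# Toy dataset: keep only elements <= counter, so the threshold
# grows by one per epoch.
dataset = tf.data.Dataset.range(10)
dataset = dataset.filter(lambda x: x <= counter)

iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
init = iterator.initializer
next_element = iterator.get_next()

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for epoch in range(3):
        sess.run(init)  # re-initialize so the filter sees the new value
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break
        sess.run(increment_counter)

If re-initialization refreshes the captured value as described above, each epoch should yield one more element than the previous one (epoch 0 prints only 0, epoch 1 prints 0 and 1, and so on).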
Source: https://stackoverflow.com/questions/60882387/how-to-get-current-global-step-in-data-pipeline