How to get current global_step in data pipeline

Submitted by 半城伤御伤魂 on 2020-04-07 07:07:38

Question


I am trying to create a filter which depends on the current global_step of the training but I am failing to do so properly.

First, I cannot call tf.train.get_or_create_global_step() directly in the code below because it raises:

ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

This is why I tried fetching the scope with tf.get_default_graph().get_name_scope() and within that context I was able to "get" the global step:

def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()

    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)

    return tf.size(example['targets']) <= max_subtokens


dataset = dataset.filter(filter_examples)
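For reference, the schedule that filter_examples is meant to implement can be checked in plain Python, independent of TensorFlow. This is only a sketch of the arithmetic: the values 8, 4 and 1000 are made-up placeholders, since the post does not state what min_subtokens, curriculum_step_size and curriculum_step_update are set to, and the helper names are hypothetical.

```python
def max_subtokens_at(step, min_subtokens, curriculum_step_size, curriculum_step_update):
    """Length threshold the filter is meant to apply at a given training step."""
    # For non-negative steps, integer division matches floor(step / update).
    completed_stages = step // curriculum_step_update
    return min_subtokens + curriculum_step_size * completed_stages


def keep_example(num_target_tokens, step, **schedule):
    """Mirror of the filter predicate: keep examples no longer than the threshold."""
    return num_target_tokens <= max_subtokens_at(step, **schedule)


# Placeholder values, assumed purely for illustration.
schedule = dict(min_subtokens=8, curriculum_step_size=4, curriculum_step_update=1000)

for step in (0, 999, 1000, 2500):
    print(step, max_subtokens_at(step, **schedule))
```

With these placeholder values, an example with 10 target tokens would be rejected at step 0 and accepted from step 1000 onward, which is the behavior the filter only shows if current_step actually advances.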

The problem is that this does not work as I expected. From what I observe, current_step in the code above seems to be 0 all the time (I cannot verify this directly; it is an assumption based on my observations).

The only thing that seems to make a difference, odd as it sounds, is restarting the training. In that case, again judging from observations, current_step appears to hold the actual training step at the moment of the restart, but the value does not update as training continues.

Is there a way to get the actual value of the current step and use it in my filter as shown above?


Environment

TensorFlow 1.12.1


Answer 1:


As we discussed in the comments, having and updating your own counter might be an alternative to using the global_step variable. The counter variable could be updated as follows:

op = tf.assign_add(counter, 1)
# tf.control_dependencies expects a list of ops or tensors
with tf.control_dependencies([op]):
    # Create here the operations that must run only after the counter is updated
    pass

Using tf.control_dependencies lets you "attach" the update of counter to a path within the computational graph: any operation created inside the block runs only after the update has run. You can then use the counter variable wherever you need it.
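A minimal, self-contained sketch of this idea follows. It assumes a TensorFlow release that ships the tensorflow.compat.v1 module (newer than the 1.12.1 from the question) and uses a resource variable so that the read is ordered after the increment. The names curriculum_counter, build_counter_graph and run_steps are hypothetical, and the sketch only demonstrates the control dependency itself, not the dataset filter.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, as in TF 1.x


def build_counter_graph():
    """Build a graph in which fetching `value` forces a counter increment first."""
    graph = tf.Graph()
    with graph.as_default():
        # Hypothetical stand-in for global_step.
        counter = tf.get_variable(
            "curriculum_counter", shape=[], dtype=tf.int64,
            initializer=tf.zeros_initializer(), trainable=False,
            use_resource=True)
        increment = tf.assign_add(counter, 1)
        with tf.control_dependencies([increment]):
            # The read op is created inside the block, so it runs after the increment.
            value = counter.read_value()
        init = tf.global_variables_initializer()
    return graph, init, value


def run_steps(num_steps):
    """Fetch the counter `num_steps` times and return the values that were read."""
    graph, init, value = build_counter_graph()
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        return [int(sess.run(value)) for _ in range(num_steps)]


print(run_steps(3))
```

In a real training loop, the increment would be attached to the training op instead of a plain read, so that the counter advances once per training step.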




Answer 2:


If you use variables inside datasets, you need to reinitialize the iterators in TF 1.x.

iterator = tf.compat.v1.make_initializable_iterator(dataset)
init = iterator.initializer
tensors = iterator.get_next()

with tf.compat.v1.Session() as sess:
    for epoch in range(num_epochs):
        sess.run(init)
        for example in range(num_examples):
            tensor_vals = sess.run(tensors)


Source: https://stackoverflow.com/questions/60882387/how-to-get-current-global-step-in-data-pipeline
