How to use tf.data's initializable iterator and reinitializable iterator and feed data to the Estimator API?


Question


All of the official Google tutorials use the one-shot iterator for their Estimator API examples, and I couldn't find any documentation on how to use tf.data's initializable iterator or reinitializable iterator instead of the one-shot iterator.

Can someone kindly show me how to switch between train_data and test_data using tf.data's initializable iterator or reinitializable iterator? With the initializable iterator we need to run a session and use a feed_dict to switch the dataset; it is a low-level API, and it is confusing how to use it as part of the Estimator API architecture.
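
For context, this is roughly the low-level pattern I mean (a minimal sketch in plain TF 1.x sessions, with made-up shapes and dummy data, outside the Estimator API):

import numpy as np
import tensorflow as tf

X_pl = tf.placeholder(tf.float32, shape=[None, 4])
y_pl = tf.placeholder(tf.int32, shape=[None])

dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl)).batch(32)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

X_train, y_train = np.random.rand(100, 4), np.random.randint(0, 2, 100)
X_test, y_test = np.random.rand(20, 4), np.random.randint(0, 2, 20)

with tf.Session() as sess:
    # Switch to the training data by running the initializer with a feed_dict.
    sess.run(iterator.initializer, feed_dict={X_pl: X_train, y_pl: y_train})
    sess.run(next_batch)

    # Later, switch to the test data by re-running the initializer the same way.
    sess.run(iterator.initializer, feed_dict={X_pl: X_test, y_pl: y_test})
    sess.run(next_batch)

The question is how to fit this initializer/feed_dict step into the Estimator workflow, where the session is managed for you.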

PS: I did find that Google mentions "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."

But is there any workaround within the community? Or should we just stick with the one-shot iterator for some good reason?


Answer 1:


To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. This class then has access to the session used by the tf.estimator functions.

Here is a quick example that you can adapt to your needs:

import tensorflow as tf


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook that initializes a dataset iterator once the session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Called right after the Estimator creates its session; run the
        # iterator initializer with the real data fed in.
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        # Placeholders keep the (potentially large) arrays out of the graph.
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...  # shuffle, repeat, batch, etc. as needed
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()

        # Defer initialization until the session exists, then feed the data.
        iterator_initializer_hook.iterator_initializer_func = \
            lambda sess: sess.run(iterator.initializer,
                                  feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])

This is a modified version of code I found in a blog post by Sebastian Pölsterl. Have a look under the "Feeding data to an Estimator via the Dataset API" section.




Answer 2:


Or you can simply use tf.estimator.train_and_evaluate (https://www.tensorflow.org/api_docs/python/tf/estimator/train_and_evaluate). It lets you run evaluation during training without having to deal with the iterator at all.
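
For reference, a minimal sketch of how that might look, assuming the estimator and the train/test input functions from the other answer are already defined (the max_steps and steps values here are arbitrary):

import tensorflow as tf

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10000)
eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn, steps=100)

# Trains the model and periodically evaluates it on the eval input_fn.
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)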



Source: https://stackoverflow.com/questions/51625529/how-to-use-tf-datas-initializable-iterator-and-reinitializable-interator-and-fe
