Avoiding tf.data.Dataset.from_tensor_slices with estimator api

我与影子孤独终老i posted on 2021-02-18 12:35:10

Question


I am trying to figure out the recommended way to use the dataset api together with the estimator api. Everything I have seen online is some variation of this:

def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset

which can then be passed to the estimator's train function:

classifier.train(
    input_fn=train_input_fn,
    #...
)

but the dataset guide warns that:

the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.

and then describes a method that involves defining placeholders which are then filled with the feed_dict:

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

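dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# The iterator must exist before its initializer can be run below.
iterator = dataset.make_initializable_iterator()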

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

But if you're using the estimator api, you're not manually running the session. So how do you use the dataset api with estimators while avoiding the problems associated with from_tensor_slices()?
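For a rough sense of when the 2GB limit quoted from the guide becomes a problem, here is a back-of-the-envelope sketch in plain Python. The helper names, the example shapes, and the choice to treat the limit as exactly 2**31 bytes are assumptions made for illustration; they do not come from TensorFlow itself.

```python
# Assumption: the "2GB" GraphDef limit is treated as 2**31 bytes.
GRAPHDEF_LIMIT_BYTES = 2 ** 31


def array_nbytes(shape, itemsize):
    """Bytes needed for a dense array with the given shape and element size."""
    count = 1
    for dim in shape:
        count *= dim
    return count * itemsize


def exceeds_graphdef_limit(*arrays):
    """arrays: (shape, itemsize) pairs that would be embedded as constants."""
    total = sum(array_nbytes(shape, itemsize) for shape, itemsize in arrays)
    return total >= GRAPHDEF_LIMIT_BYTES


# Hypothetical dataset: 1,000,000 examples with 1,000 float32 features each,
# plus one int64 label per example.
features_spec = ((1_000_000, 1_000), 4)
labels_spec = ((1_000_000,), 8)

print(array_nbytes(*features_spec))                        # 4000000000
print(exceeds_graphdef_limit(features_spec, labels_spec))  # True
```

A dataset of this size could not be embedded as constants even once, which is why the placeholder approach matters for anything beyond small arrays.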


Answer 1:


To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. A hook is given access to the session at several points during training and evaluation.

You can then use this new class to initialize the iterator as you would normally do in a classic setting. You simply need to pass the newly created hook to the training/evaluation functions or to the corresponding train/eval spec.

Here is a quick example that you can adapt to your needs:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session) 


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
                                                                                    feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
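The reason the hook starts with iterator_initializer_func set to None is ordering: input_fn runs while the graph is being built, and the session is only created afterwards. The following TensorFlow-free sketch mimics that order. FakeSession, InitializerHook, and fake_train are hypothetical stand-ins written for this illustration, and the strings stand in for real placeholders and ops; none of this is the estimator's actual implementation.

```python
class FakeSession:
    """Stand-in for tf.Session; records what it was asked to run."""

    def __init__(self):
        self.calls = []

    def run(self, op, feed_dict=None):
        self.calls.append((op, sorted(feed_dict or {})))


class InitializerHook:
    """Mirrors IteratorInitializerHook without the TensorFlow base class."""

    def __init__(self):
        self.iterator_initializer_func = None  # filled in by input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    hook = InitializerHook()

    def input_fn():
        # Strings stand in for the placeholders and the initializer op.
        hook.iterator_initializer_func = lambda sess: sess.run(
            "iterator.initializer", feed_dict={"X_pl": X, "y_pl": y})
        return "next_example", "next_label"

    return input_fn, hook


def fake_train(input_fn, hooks):
    """Hypothetical driver: build graph, create session, then call hooks."""
    outputs = input_fn()           # 1. graph construction sets the closure
    session = FakeSession()        # 2. session is created afterwards
    for hook in hooks:             # 3. hooks get the live session
        hook.after_create_session(session, coord=None)
    return outputs, session


input_fn, hook = get_inputs([1, 2, 3], [0, 1, 0])
outputs, session = fake_train(input_fn, [hook])
print(session.calls)  # [('iterator.initializer', ['X_pl', 'y_pl'])]
```

If the hook is left out of the hooks list, step 3 never happens and the iterator is never initialized, which is why the hook has to be passed explicitly to train and evaluate.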


Source: https://stackoverflow.com/questions/52266000/avoiding-tf-data-dataset-from-tensor-slices-with-estimator-api
