Question
I am trying to figure out the recommended way to use the Dataset API together with the Estimator API. Everything I have seen online is some variation of this:
```python
import tensorflow as tf

def train_input_fn():
    # features and labels are arrays already loaded in memory
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset
```
which can then be passed to the estimator's train function:
```python
classifier.train(
    input_fn=train_input_fn,
    # ...
)
```
but the dataset guide warns that:
the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.
and then describes a method that defines placeholders, which are then filled in through a feed_dict:
```python
import tensorflow as tf  # TensorFlow 1.x API

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                              labels_placeholder: labels})
```
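To judge whether the 2GB tf.GraphDef limit from the guide is actually a concern for a given dataset, a rough size estimate is usually enough. The sketch below is a hypothetical helper (array_nbytes, fits_in_graphdef and GRAPHDEF_LIMIT_BYTES are not TensorFlow names). It assumes dense arrays and compares only their raw bytes against roughly 2 GiB, ignoring the rest of the graph and the extra copies the guide mentions, so a True result is optimistic. With NumPy arrays, features.nbytes gives the raw size directly.

```python
# Hypothetical helper: estimate whether in-memory arrays could be embedded as
# tf.constant ops without hitting the ~2 GiB protobuf size cap on a GraphDef.
GRAPHDEF_LIMIT_BYTES = 2 * 1024**3  # approximate cap, treated as a hard ceiling here


def array_nbytes(shape, itemsize):
    """Bytes needed for a dense array of `shape` with `itemsize` bytes per element."""
    n = 1
    for dim in shape:
        n *= dim
    return n * itemsize


def fits_in_graphdef(arrays, limit=GRAPHDEF_LIMIT_BYTES):
    """Return True if the combined raw size of `arrays` stays under `limit`.

    `arrays` is a list of (shape, itemsize) pairs. Other graph contents and any
    duplicated copies are ignored, so True is an optimistic answer.
    """
    total = sum(array_nbytes(shape, itemsize) for shape, itemsize in arrays)
    return total < limit
```

For example, 60,000 float32 images of 28×28 pixels take 188,160,000 bytes and pass the check, whereas 1,000,000 such images take 3,136,000,000 bytes and do not.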
But if you're using the Estimator API, you're not running the session manually. So how do you use the Dataset API with estimators while avoiding the problems associated with from_tensor_slices()?
Answer 1:
To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook, which gets access to the session at several points during training and evaluation.
You can then use this new class to initialize the iterator as you would in a classic setting. You simply need to pass the newly created hook to the training/evaluation functions, or to the appropriate TrainSpec/EvalSpec if you use tf.estimator.train_and_evaluate.
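To make the timing concrete, here is a framework-free sketch of the order in which a MonitoredSession (which Estimator uses internally) invokes SessionRunHook callbacks: begin() once, after_create_session() once the session exists, before_run()/after_run() around every step, and end() on a clean close. RecordingHook and run_with_hooks are hypothetical stand-ins and the "session" is a plain dict; the only point is that after_create_session() runs before the first step, which is why the iterator initializer belongs there.

```python
# Framework-free sketch: RecordingHook and run_with_hooks are hypothetical stand-ins
# that mimic the SessionRunHook call order; no TensorFlow code is involved.


class RecordingHook:
    """Defines the five SessionRunHook callback names and logs each call."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")

    def after_create_session(self, session, coord):
        # In TensorFlow, sess.run(iterator.initializer, feed_dict=...) would go here.
        session["iterator_ready"] = True
        self.calls.append("after_create_session")

    def before_run(self, run_context):
        self.calls.append("before_run")

    def after_run(self, run_context, run_values):
        self.calls.append("after_run")

    def end(self, session):
        self.calls.append("end")


def run_with_hooks(hooks, num_steps):
    """Invoke hooks in MonitoredSession order; the 'session' is a plain dict."""
    for hook in hooks:
        hook.begin()  # once, before the graph is finalized
    session = {"iterator_ready": False}
    for hook in hooks:
        hook.after_create_session(session, None)  # once, before any step
    for _ in range(num_steps):
        for hook in hooks:
            hook.before_run(None)
        # Stand-in for the training step, which needs an initialized iterator.
        if not session["iterator_ready"]:
            raise RuntimeError("iterator used before initialization")
        for hook in hooks:
            hook.after_run(None, None)
    for hook in hooks:
        hook.end(session)  # once, when the session closes without error
    return session
```

A hook that skips the initialization in after_create_session() makes the first simulated step fail, which mirrors the uninitialized-iterator error you would get in TensorFlow.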
Here is a quick example that you can adapt to your needs:
```python
import tensorflow as tf  # TensorFlow 1.x API


class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)
        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        # dataset = ...  (further transformations on dataset go here)
        # ...
        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()
        # Defer feeding until the hook has a session; X and y are fed, not embedded
        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer, feed_dict={X_pl: X, y_pl: y})
        return next_example, next_label

    return input_fn, iterator_initializer_hook

# ...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

# ...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])  # Don't forget to pass the hook!
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
```
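One subtlety in get_inputs() is that Estimator builds a fresh graph and calls input_fn again on each train() or evaluate() call, while the hook object is created only once. Because input_fn reassigns iterator_initializer_func every time it runs, the hook always initializes the iterator that belongs to the current graph. The framework-free sketch below mimics that rebinding; FakeIterator, InitializerHook and get_inputs_sketch are hypothetical stand-ins, and the session argument is unused.

```python
# Framework-free sketch of the get_inputs() pairing; FakeIterator, InitializerHook
# and get_inputs_sketch are hypothetical, and TensorFlow is not involved.
import itertools

_graph_ids = itertools.count()  # each input_fn call stands in for a freshly built graph


class FakeIterator:
    def __init__(self, graph_id):
        self.graph_id = graph_id
        self.data = None  # None means "not initialized yet"

    def initialize(self, data):
        self.data = list(data)


class InitializerHook:
    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


def get_inputs_sketch(X):
    hook = InitializerHook()

    def input_fn():
        iterator = FakeIterator(next(_graph_ids))
        # Rebind on every call so the hook targets the newest graph's iterator.
        hook.iterator_initializer_func = lambda sess: iterator.initialize(X)
        return iterator

    return input_fn, hook
```

Calling input_fn a second time replaces the callback, so the next after_create_session() call initializes the new iterator rather than the stale one from the previous graph.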
Source: https://stackoverflow.com/questions/52266000/avoiding-tf-data-dataset-from-tensor-slices-with-estimator-api