Question
I would like to manage my training with a tf.estimator.Estimator, but I am having some trouble using it alongside the tf.data API.
I have something like this:
import tensorflow as tf

def model_fn(features, labels, params, mode):
    # Defines the model's ops.
    # Initializes with tf.train.Scaffold.
    # Returns a tf.estimator.EstimatorSpec.
    ...  # Model body omitted.

def input_fn():
    dataset = tf.data.TextLineDataset("test.txt")
    # map, shuffle, padded_batch, etc.
    iterator = dataset.make_initializable_iterator()
    return iterator.get_next()

estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
As I can't use make_one_shot_iterator for my use case, my issue is that input_fn contains an iterator that must be initialized within model_fn (here, I use tf.train.Scaffold to initialize local ops).
Also, I understand that I can't simply use input_fn = iterator.get_next, because otherwise the other ops would not be added to the same graph.
What is the recommended way to initialize the iterator?
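To make the problem concrete, here is a minimal sketch, assuming TF 1.15 or TF 2.x through tf.compat.v1 in graph mode, with tf.data.Dataset.range standing in for the TextLineDataset pipeline. It shows that an initializable iterator cannot produce elements until its initializer has been run:

```python
import tensorflow.compat.v1 as tf

# Graph mode is required for initializable iterators and tf.Session.
tf.disable_eager_execution()

with tf.Graph().as_default():
    # Dataset.range stands in for the TextLineDataset pipeline.
    dataset = tf.data.Dataset.range(3)
    iterator = tf.data.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

    with tf.Session() as sess:
        try:
            # Fails: the iterator's initializer has not been run yet.
            sess.run(next_element)
            error_name = None
        except tf.errors.FailedPreconditionError as e:
            error_name = type(e).__name__

        # Once the initializer has run, elements can be fetched.
        sess.run(iterator.initializer)
        first = sess.run(next_element)

print(error_name, first)  # FailedPreconditionError 0
```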
Answer 1:
As of TensorFlow 1.5, it is possible to make input_fn return a tf.data.Dataset, e.g.:
import tensorflow as tf

def input_fn():
    dataset = tf.data.TextLineDataset("test.txt")
    # map, shuffle, padded_batch, etc.
    return dataset
See commit c294fcfd.
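When input_fn returns a Dataset, the Estimator builds and initializes the iterator itself; as far as I know, it does so with a session hook that runs the iterator's initializer right after the session is created. The same pattern can be written by hand for older versions. Below is a minimal sketch outside of an Estimator, assuming TF 1.15 or TF 2.x via tf.compat.v1. IteratorInitializerHook is a hypothetical name, and tf.train.MonitoredSession stands in for the session the Estimator would create.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hypothetical hook that runs an iterator's initializer once the session exists."""

    def __init__(self, iterator):
        self._iterator = iterator

    def after_create_session(self, session, coord):
        # Called once the session is created, before any training step runs.
        session.run(self._iterator.initializer)


with tf.Graph().as_default():
    dataset = tf.data.Dataset.range(4).batch(2)
    iterator = tf.data.make_initializable_iterator(dataset)
    next_batch = iterator.get_next()

    batches = []
    # MonitoredSession suppresses the OutOfRangeError raised at the end of
    # the input when the with block exits.
    with tf.train.MonitoredSession(hooks=[IteratorInitializerHook(iterator)]) as sess:
        while not sess.should_stop():
            batches.append(sess.run(next_batch).tolist())

print(batches)  # [[0, 1], [2, 3]]
```

With an Estimator, the iterator only exists inside input_fn, so such a hook would have to receive it from there (for example, through an attribute set inside input_fn) and be passed via estimator.train(input_fn, hooks=[hook]).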
For previous versions, you can add the iterator's initializer to the tf.GraphKeys.TABLE_INITIALIZERS collection and rely on the default initializer (the default tf.train.Scaffold runs all table initializers as part of its local init op):
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
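The sketch below illustrates why this works, assuming TF 1.15 or TF 2.x via tf.compat.v1: tf.tables_initializer() groups every op in the TABLE_INITIALIZERS collection, and that group is what the default Scaffold's local init op runs. Here the group is run by hand in place of the Scaffold, and make_batches is a hypothetical stand-in for input_fn.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def make_batches():
    # Hypothetical stand-in for input_fn.
    dataset = tf.data.Dataset.range(5).batch(2)
    iterator = tf.data.make_initializable_iterator(dataset)
    # Register the iterator's initializer alongside the table initializers.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    return iterator.get_next()


with tf.Graph().as_default():
    next_batch = make_batches()
    # Groups all ops in TABLE_INITIALIZERS, including the iterator initializer.
    init_op = tf.tables_initializer()

    batches = []
    with tf.Session() as sess:
        sess.run(init_op)
        try:
            while True:
                batches.append(sess.run(next_batch).tolist())
        except tf.errors.OutOfRangeError:
            pass  # End of the dataset.

print(batches)  # [[0, 1], [2, 3], [4]]
```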
Source: https://stackoverflow.com/questions/45011724/how-to-use-tf-datas-initializable-iterators-within-a-tf-estimators-input-fn