The established way to use the TF Dataset API in Keras is to feed `model.fit` with `make_one_shot_iterator()`, but this iterator is only good for one epoch

跟風遠走, submitted on 2019-12-13 03:49:55

Question


Edit:

To clarify why this question is different from the suggested duplicates: this question follows up on them and asks what exactly Keras is doing with the techniques they describe. The suggested duplicates pass a Dataset API make_one_shot_iterator() to model.fit. My follow-up point is that make_one_shot_iterator() can only go through the dataset once, yet the solutions given specify several epochs.


This is a follow-up to these SO questions:

How to Properly Combine TensorFlow's Dataset API and Keras?

Tensorflow keras with tf dataset input

Using tf.data.Dataset as training input to Keras model NOT working

They state that "Starting from Tensorflow 1.9, one can pass tf.data.Dataset object directly into keras.Model.fit() and it would act similar to fit_generator". Each example feeds a TF Dataset one-shot iterator into Keras's model.fit.

An example is given below:

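import tensorflow as tf

# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# tfdata_generator is not a TensorFlow function; it is defined in an answer
# to one of the linked questions and returns a tf.data.Dataset
training_set = tfdata_generator(x_train, y_train, is_training=True)

model = ...  # your Keras model here (built and compiled)
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,
    epochs=5,
    verbose=1)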

However, according to the TensorFlow Dataset API guide (https://www.tensorflow.org/guide/datasets):

A one-shot iterator is the simplest form of iterator, which only supports iterating once through a dataset

So it's only good for one epoch. However, the code in those SO questions specifies several epochs; the example above specifies 5.
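The one-shot behavior can be illustrated with plain Python iterators, which can also be walked only once. The contrast case shows one way the contradiction could dissolve: if the dataset behind the iterator was built with repeat(), a single pass never ends. This is a pure-Python analogy, assuming a plain iterator stands in for the one-shot iterator and itertools.cycle stands in for Dataset.repeat(); neither is the TensorFlow API:

```python
import itertools

data = [1, 2, 3, 4]

# A plain iterator, like a one-shot iterator, supports a single pass.
one_shot = iter(data)
first_pass = list(one_shot)    # consumes every element
second_pass = list(one_shot)   # already exhausted, so nothing is left

# Stand-in for a repeated dataset: the stream never ends, so one iterator
# over it can feed any number of "epochs" of len(data) elements each.
repeated = itertools.cycle(data)
epochs = [[next(repeated) for _ in range(len(data))] for _ in range(3)]

print(first_pass, second_pass)
print(epochs)
```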

Is there any explanation for this contradiction? Does Keras somehow know that, once the one-shot iterator has gone through the dataset, it can re-initialize it and reshuffle the data?


Answer 1:


You can simply pass the dataset object to model.fit; Keras will handle the iteration. Consider one of the pre-made datasets:

import tensorflow as tf

train, test = tf.keras.datasets.cifar10.load_data()
# train[0] holds the images, train[1] the labels
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))

This creates a dataset object from the CIFAR-10 training data. In this case no parse function is needed. If you create the dataset from a list of paths to image files (or to saved NumPy arrays), you'll need one.
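from_tensor_slices slices each component of the (images, labels) tuple along its first dimension, so every dataset element is already an in-memory (image, label) pair; that is why no parse function is needed here. A pure-Python sketch of that slicing, where from_tensor_slices_like is a hypothetical stand-in rather than a TensorFlow function:

```python
def from_tensor_slices_like(components):
    """Hypothetical stand-in: slice (features, labels) along the first dimension."""
    features, labels = components
    if len(features) != len(labels):
        raise ValueError("all components must have the same first dimension")
    # Each element of the resulting "dataset" is one (feature, label) pair.
    return list(zip(features, labels))

images = [[0, 0], [1, 1], [2, 2]]   # three toy "images"
labels = [7, 8, 9]
elements = from_tensor_slices_like((images, labels))
print(elements[0])   # ([0, 0], 7)
```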

dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path)) 

In that case you'll need a function that loads the actual data from each filename. NumPy arrays can be handled the same way, just without tf.read_file.

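def parse_func(filename, label):
    # Each dataset element is a (filename, label) pair, so map() passes both
    f = tf.read_file(filename)          # raw file bytes (tf.io.read_file in TF 2.x)
    image = tf.image.decode_image(f)    # decode BMP/GIF/JPEG/PNG bytes into a uint8 tensor
    return image, label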
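The parse function leaves open how the label is obtained. If labels are encoded in the file paths, one option is to compute them in plain Python before building the dataset and pass them alongside the paths to from_tensor_slices. A minimal sketch, assuming a hypothetical `<root>/<class_name>/<file>` layout and a hypothetical CLASS_NAMES list:

```python
from pathlib import PurePosixPath

# Hypothetical class list; in practice it depends on your own data layout.
CLASS_NAMES = ["cat", "dog"]

def label_from_filename(filename):
    """Assumes the hypothetical layout <root>/<class_name>/<file>,
    so the parent directory name identifies the class."""
    class_name = PurePosixPath(filename).parent.name
    return CLASS_NAMES.index(class_name)

paths = ["data/cat/001.png", "data/dog/002.png"]
labels = [label_from_filename(p) for p in paths]
print(labels)   # [0, 1]
```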

Then you can shuffle and batch this dataset, and map any parse function over it. The shuffle buffer size controls how many examples are preloaded for shuffling. repeat() controls the epoch count and is best left with its default count of None, so the dataset repeats indefinitely. You can either use the plain batch method or combine mapping and batching with tf.data.experimental.map_and_batch:

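dataset = dataset.shuffle(buffer_size=1000).repeat()  # 1000 is an example value; no count = repeat forever
dataset = dataset.apply(tf.data.experimental.map_and_batch(
    map_func=parse_func, batch_size=32, num_parallel_batches=4))  # example values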
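The TensorFlow documentation describes shuffle(buffer_size) as filling a buffer with buffer_size elements, then randomly sampling from it and refilling the freed slot with the next element. A small buffer therefore only mixes nearby examples. Below is a simplified pure-Python sketch of that buffer strategy plus a batch step that, like batch() with the default drop_remainder=False, keeps a smaller final batch; the function names are hypothetical and this is not TensorFlow's implementation:

```python
import random

def buffered_shuffle(elements, buffer_size, rng):
    """Simplified fixed-size shuffle buffer (hypothetical helper)."""
    buffer = []
    for x in elements:
        buffer.append(x)
        if len(buffer) == buffer_size:
            # Buffer full: emit a randomly chosen element to free one slot.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(elements, batch_size):
    """Group a stream into lists of batch_size; the last batch may be smaller."""
    current = []
    for x in elements:
        current.append(x)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

rng = random.Random(0)
shuffled = list(buffered_shuffle(range(10), buffer_size=3, rng=rng))
batches = list(batch(shuffled, batch_size=4))
print(batches)
```

With buffer_size=1 the order is left unchanged, and the first output can only come from the first buffer_size inputs, which is why the buffer should be large relative to how ordered the data is.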

Then the dataset object can be passed to model.fit, as in model.fit(dataset, epochs=..., steps_per_epoch=...). Note that steps_per_epoch is a required parameter in this case: the repeated dataset never ends, so steps_per_epoch is what defines when a new epoch starts. This means you have to know the epoch size in advance.
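In this setup nothing has to be re-initialized: the training loop keeps pulling batches from one endless iterator, and steps_per_epoch alone marks where each epoch ends. A pure-Python sketch of that bookkeeping, assuming a dataset size divisible by the batch size so each epoch is exactly one reshuffled pass; fit_like and the generator names are hypothetical and are not Keras internals:

```python
import itertools
import random

def repeated_shuffled_stream(data, rng):
    """Endless element stream, reshuffled on every pass (an idealized
    shuffle(...).repeat() whose buffer holds the whole dataset)."""
    while True:
        order = list(data)
        rng.shuffle(order)
        yield from order

def batched(stream, batch_size):
    """Batching applied after repeat(): every batch is full."""
    while True:
        yield list(itertools.islice(stream, batch_size))

def fit_like(batches, epochs, steps_per_epoch):
    """Hypothetical training loop: one iterator for the whole run, never re-created."""
    history = []
    for _ in range(epochs):
        history.append([next(batches) for _ in range(steps_per_epoch)])
    return history

data = list(range(12))
batch_size = 4
steps_per_epoch = len(data) // batch_size   # same rule as len(x_train) // 128
stream = batched(repeated_shuffled_stream(data, random.Random(0)), batch_size)
history = fit_like(stream, epochs=5, steps_per_epoch=steps_per_epoch)
for epoch in history:
    print(epoch)

# Without repeat(), the same loop runs dry after one pass.
one_pass = iter([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
try:
    fit_like(one_pass, epochs=2, steps_per_epoch=3)
except StopIteration:
    print("iterator exhausted in epoch 2")
```

If the dataset size is not a multiple of the batch size, batches after repeat() straddle pass boundaries, so an "epoch" of steps_per_epoch batches no longer lines up exactly with one pass over the data.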



Source: https://stackoverflow.com/questions/55444615/the-established-way-to-use-tf-dataset-api-in-keras-is-to-feed-model-fit-with
