Question
Even after adding the `.cache()` step to my dataset pipeline, successive training epochs still download the data from the network storage.

I have a dataset on network storage. I want to cache it, but not repeat it: each training epoch must run through the whole dataset exactly once. Here is my dataset-building pipeline:
```python
return tf.data.Dataset.list_files(
    file_pattern
).interleave(
    tf.data.TFRecordDataset,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).shuffle(
    buffer_size=2048
).batch(
    batch_size=2048,
    drop_remainder=True,
).cache(
).map(
    map_func=_parse_example_batch,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).prefetch(
    buffer_size=32
)
```
If I use it as is, the dataset is downloaded at each epoch. To avoid this, I have to add the `.repeat()` step to the pipeline and pass the `steps_per_epoch` argument to the `model.fit` function. However, I do not know the size of the complete dataset, so I cannot pass the right `steps_per_epoch` value.
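For context, here is a toy pure-Python model (not TensorFlow code) of what `.cache().repeat()` combined with `steps_per_epoch` does: the repeated pipeline becomes an endless stream, and each epoch simply takes the next `steps_per_epoch` batches from it. The helper `split_into_epochs` and the batch labels are hypothetical. It illustrates why a wrong value breaks the "one epoch = one full pass" requirement.

```python
import itertools
from typing import List, Sequence


def split_into_epochs(batches: Sequence[str], steps_per_epoch: int,
                      num_epochs: int) -> List[List[str]]:
    """Toy model of `dataset.cache().repeat()` consumed with `steps_per_epoch`.

    itertools.cycle replays the batches forever (like `.repeat()`), and each
    "epoch" just takes the next `steps_per_epoch` batches from that stream.
    """
    stream = itertools.cycle(batches)
    return [[next(stream) for _ in range(steps_per_epoch)]
            for _ in range(num_epochs)]


batches = ["b0", "b1", "b2"]  # pretend the dataset has exactly 3 batches
print(split_into_epochs(batches, steps_per_epoch=3, num_epochs=2))
# [['b0', 'b1', 'b2'], ['b0', 'b1', 'b2']]  -> epochs align with the data
print(split_into_epochs(batches, steps_per_epoch=2, num_epochs=2))
# [['b0', 'b1'], ['b2', 'b0']]  -> epoch boundaries drift
```

With the right value, every epoch covers the dataset exactly once; with a wrong one, epochs start in the middle of the data.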
What is the right way to cache and use a dataset of unknown size?
Thanks.
Edit
While reading some TF code, I (re)discovered `make_initializable_iterator`. It seems to be what I am looking for, that is, a way to iterate multiple times through the same dataset (taking advantage of the cache after the first iteration). However, it is deprecated and no longer part of the main API in TF2. The migration instructions say to iterate over the dataset manually with `for ... in dataset`. Isn't that what the `keras.Model.fit` function does? Do I have to write the training loop manually to benefit from the cache?
Kind regards.
Answer 1:
In TF 2.0, you do not need `.repeat()`. Regarding:

> successive training epochs still download the data from the network storage.

I think you are confusing this with the `filling up shuffle buffer` message, which appears before every epoch if you are using the `shuffle()` function. Maybe try without `shuffle()`, just to see the difference.

Also, I would suggest using `cache()` after `map()` and before `batch()`.
EDIT
`filling up shuffle buffer` is a message you get when using the `shuffle` function. You can still `shuffle()` the dataset after using `cache()`. Look here.

Also, if I understood correctly, you are feeding the dataset that results from `map()` to your model for training, so you should `cache()` that dataset rather than the earlier one, because training is done on it.
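A minimal sketch of the reordered pipeline suggested above, assuming each record holds a single int64 feature named `x` (a hypothetical schema) and a local demo file in place of the network storage. Because `map()` now runs before `batch()`, it uses a per-example parser (`tf.io.parse_single_example`) instead of a batch-level one like `_parse_example_batch`; `parse_example`, `make_pipeline`, and `write_demo_file` are illustrative names.

```python
import os
import tempfile

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Hypothetical schema: one int64 scalar per example, stored under "x".
FEATURES = {"x": tf.io.FixedLenFeature([], tf.int64)}


def parse_example(serialized):
    # Per-example parser: map() now runs before batch().
    return tf.io.parse_single_example(serialized, FEATURES)["x"]


def make_pipeline(file_pattern, batch_size):
    ds = tf.data.Dataset.list_files(file_pattern)
    ds = ds.interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTOTUNE)
    ds = ds.map(parse_example, num_parallel_calls=AUTOTUNE)
    ds = ds.cache()                    # cache the parsed examples
    ds = ds.shuffle(buffer_size=2048)  # after cache(): new order each epoch
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(AUTOTUNE)


def write_demo_file(directory, n):
    # Writes n tiny examples so the sketch runs without network storage.
    path = os.path.join(directory, "demo.tfrecord")
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(n):
            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
            example = tf.train.Example(
                features=tf.train.Features(feature={"x": feature}))
            writer.write(example.SerializeToString())
    return path


demo_dir = tempfile.mkdtemp()
write_demo_file(demo_dir, 10)
dataset = make_pipeline(os.path.join(demo_dir, "*.tfrecord"), batch_size=4)
for epoch in range(2):
    # Two batches of 4 per epoch; drop_remainder discards the 2 leftovers.
    print(epoch, [batch.numpy().tolist() for batch in dataset])
```

Since `shuffle()` comes after `cache()`, the shuffle buffer is refilled at every epoch, but from the cache rather than from the network. If the parsed data does not fit in memory, `cache()` also accepts a local file path as its `filename` argument.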
To count the number of elements in your dataset, you can use the following code:

```python
num_elements = 0
for element in dataset:  # dataset is a tf.data.Dataset
    num_elements += 1
print('Total number of elements in the file:', num_elements)
```
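The same count can also be written with `tf.data.Dataset.reduce`, which keeps the loop inside TensorFlow instead of pulling every element back into Python. `count_elements` is a hypothetical helper name, and `tf.data.Dataset.range(10)` stands in for the real dataset. If you count the already-batched dataset, the result is the number of batches, which is directly the `steps_per_epoch` value.

```python
import numpy as np
import tensorflow as tf


def count_elements(dataset):
    # Folds over the dataset, adding 1 per element; the element itself is ignored.
    return int(dataset.reduce(np.int64(0), lambda count, _: count + 1))


ds = tf.data.Dataset.range(10)  # stand-in for the real dataset
print(count_elements(ds))                                # 10 elements
print(count_elements(ds.batch(4, drop_remainder=True)))  # 2 batches
```

Either way, counting requires one full pass over the data, so on network storage it costs one complete read.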
Now, by dividing `num_elements` by your `batch_size`, you get `steps_per_epoch` (rounded down when you batch with `drop_remainder=True`, as in your pipeline).
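The division only comes out exact when the dataset size is a multiple of the batch size; otherwise the rounding depends on `drop_remainder`. A small sketch (`compute_steps_per_epoch` is a hypothetical helper, not a Keras function):

```python
def compute_steps_per_epoch(num_elements: int, batch_size: int,
                            drop_remainder: bool) -> int:
    # drop_remainder=True discards the last partial batch: round down.
    # drop_remainder=False keeps it as a smaller final batch: round up.
    if drop_remainder:
        return num_elements // batch_size
    return -(-num_elements // batch_size)  # integer ceiling division


print(compute_steps_per_epoch(10_000, 2048, drop_remainder=True))   # 4
print(compute_steps_per_epoch(10_000, 2048, drop_remainder=False))  # 5
```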
Answer 2:
Good news! The final v2.0.0 release fixes this behavior.
Here is a code snippet that highlights the difference between the two versions.
```python
import time

import tensorflow as tf
import tensorflow.keras as keras


# Simple layer that just prints its inputs
class Print(keras.layers.Layer):
    def compute_output_signature(self, input_signature):
        return input_signature

    def call(self, inputs, **kwargs):
        tf.print(inputs)
        return inputs


# Generator returning incremented values each time it is re-initialized
generator_list = [0]


def generator():
    v = generator_list[-1]
    generator_list.append(v + 1)
    tf.print("Generating samples with value {}".format(v))
    time.sleep(2)
    for i in range(2):
        yield (tf.constant([v]), tf.constant(v))


def main():
    model_input = keras.layers.Input(shape=(1,))
    model_output = Print()(model_input)
    model = keras.Model(inputs=model_input, outputs=model_output)
    model.compile("adam", loss="mae")

    ds = tf.data.Dataset.from_generator(
        generator, (tf.int64, tf.int64), ([1], [])
    )
    cached_ds = ds.cache()

    tf.print("Fit")
    model.fit(
        cached_ds,
        epochs=3,
        verbose=2
    )

    tf.print("For ... in ...")
    for i in range(3):
        for x, y in cached_ds:
            model(x)


if __name__ == '__main__':
    main()
```
With TensorFlow 2.0.0-b1 (used on Google AI Platform), here is the output:
```
Fit
Epoch 1/3
Generating samples with value 0
# sleep 2s
2019-10-03 15:45:32.718522: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1483] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU. To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
[[0]]
[[0]]
2/2 - 2s - loss: 0.0000e+00
Generating samples with value 1
# sleep 2s
Epoch 2/3
[[1]]
[[1]]
2/2 - 2s - loss: 0.0000e+00
Epoch 3/3
2019-10-03 15:45:34.774195: W tensorflow/core/kernels/data/cache_dataset_ops.cc:815] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Generating samples with value 2
# sleep 2s
[[2]]
[[2]]
2019-10-03 15:45:36.782046: W tensorflow/core/kernels/data/cache_dataset_ops.cc:815] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2/2 - 2s - loss: 0.0000e+00
For ... in ...
Generating samples with value 3
# sleep 2s
[3]
[3]
Generating samples with value 4
# sleep 2s
[4]
[4]
Generating samples with value 5
# sleep 2s
[5]
[5]
```
You can see that the tensor value is incremented at each epoch and the sleep instruction is executed every time. Moreover, we get the warning about a truncated iterator.
Now, with TensorFlow 2.0.0:
```
Fit
Epoch 1/3
WARNING:tensorflow:The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.
Generating samples with value 0
# sleep 2s
[[0]]
[[0]]
2019-10-03 15:49:59.587796: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
2/2 - 2s - loss: 0.0000e+00
Epoch 2/3
[[0]]
[[0]]
2019-10-03 15:49:59.598144: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
2/2 - 0s - loss: 0.0000e+00
Epoch 3/3
[[0]]
[[0]]
2019-10-03 15:49:59.605260: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
For ... in ...
2/2 - 0s - loss: 0.0000e+00
[0]
[0]
[0]
[0]
[0]
[0]
```
And voilà! The generator function is executed only once: no more sleeps, and the tensor value is always the same. I still get some warnings about end of sequence, but I can live with that!
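To check which of the two behaviors your TensorFlow version has without building a Keras model, here is a minimal sketch that iterates a cached generator dataset three times and counts how often the generator is started. It assumes TF 2.4 or later for the `output_signature` argument; `generator_runs` is a hypothetical counter. If the in-memory cache is reused across passes, as in the 2.0.0 output above, the count should be 1; a count of 3 would mean every pass re-reads the source, as in the 2.0.0-b1 output.

```python
import tensorflow as tf

# Hypothetical counter: one entry per (re)start of the Python generator.
generator_runs = []


def generator():
    generator_runs.append(1)
    for i in range(3):
        yield i


# output_signature requires TF >= 2.4 (older versions use output_types).
ds = tf.data.Dataset.from_generator(
    generator,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int64),
).cache()

for epoch in range(3):
    # Each `for` loop creates a fresh iterator over the same dataset object.
    print(epoch, [int(x) for x in ds])

print("generator runs:", len(generator_runs))
```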
Kind regards.
Source: https://stackoverflow.com/questions/57977408/how-to-cache-and-iterate-through-a-dataset-of-unknown-size