I want to train a model on about 2TB of image data on gcloud storage. I saved the image data as separate tfrecords and tried to use the tensorflow data api following this ex
If you are willing to use tf.keras instead of actual Keras, you can instantiate a TFRecordDataset with the tf.data API and pass that directly to model.fit(). Bonus: you get to stream directly from Google Cloud Storage, no need to download the data first:
# Construct a TFRecordDataset
ds_train = tf.data.TFRecordDataset('gs://')  # path to TFRecords on GCS
ds_train = ds_train.shuffle(1000).batch(32)
model.fit(ds_train)
To include validation data, create a TFRecordDataset with your validation TFRecords and pass that one to the validation_data argument of model.fit(). Note: this is possible as of TensorFlow 1.9.
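A raw TFRecordDataset yields serialized strings, so in practice you will probably want a parsing step before handing the dataset to model.fit(). Here is a minimal sketch of that, not a definitive implementation. It assumes the records are tf.train.Example protos with hypothetical keys "image" (JPEG bytes) and "label" (integer), a hypothetical 224x224 input size, and a recent TensorFlow where these functions live under tf.io (older 1.x releases expose them under different names):

```python
import tensorflow as tf

# Hypothetical feature spec: adjust the keys and types to match
# how your TFRecords were actually written.
FEATURE_SPEC = {
    "image": tf.io.FixedLenFeature([], tf.string),  # assumed JPEG-encoded bytes
    "label": tf.io.FixedLenFeature([], tf.int64),
}
IMAGE_SIZE = (224, 224)  # assumed model input size


def parse_example(serialized):
    """Turn one serialized tf.train.Example into an (image, label) pair."""
    parsed = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    image = tf.io.decode_jpeg(parsed["image"], channels=3)
    # resize returns float32; scale pixel values to [0, 1]
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return image, parsed["label"]


def make_dataset(paths, batch_size, training):
    """Build a batched dataset of (image, label) pairs from TFRecord paths."""
    ds = tf.data.TFRecordDataset(paths)
    if training:
        ds = ds.shuffle(1000)  # only shuffle the training split
    return ds.map(parse_example).batch(batch_size)
```

With that in place, the call would look something like model.fit(make_dataset(train_paths, 32, True), validation_data=make_dataset(val_paths, 32, False)), where train_paths and val_paths are placeholders for your own lists of GCS paths.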
Final note: you'll need to specify the steps_per_epoch argument. A hack that I use to find the total number of examples in all TFRecord files is to simply iterate over the files and count:
import tensorflow as tf

def n_records(record_list):
    """Get the total number of records in a collection of TFRecords.

    Since a TFRecord file is intended to act as a stream of data,
    this needs to be done naively by iterating over the file and counting.
    See https://stackoverflow.com/questions/40472139

    Args:
        record_list (list): list of GCS paths to TFRecords files
    """
    counter = 0
    for f in record_list:
        # TF 1.x API; in TF 2.x it is available as
        # tf.compat.v1.io.tf_record_iterator
        counter += sum(1 for _ in tf.python_io.tf_record_iterator(f))
    return counter
Which you can use to compute steps_per_epoch:
n_train = n_records(['gs://path-to-tfrecords/record1',
                     'gs://path-to-tfrecords/record2'])
batch_size = 32  # same value as passed to .batch() above
steps_per_epoch = n_train // batch_size
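One thing to watch: floor division leaves out the last partial batch whenever the number of examples is not a multiple of the batch size. Since .batch() keeps that partial batch by default, you may prefer to round up. A small sketch of both options (the helper name and the drop_remainder flag here are my own, for illustration only):

```python
def steps_per_epoch(n_examples, batch_size, drop_remainder=False):
    """Number of batches needed to go through n_examples once.

    drop_remainder=True mirrors floor division (the final partial batch
    is skipped); the default rounds up so every example is covered.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        return n_examples // batch_size
    # Integer ceiling division, avoids float rounding on large counts
    return (n_examples + batch_size - 1) // batch_size
```

For example, with 1000 examples and a batch size of 32, floor division gives 31 steps (992 examples seen per epoch), while rounding up gives 32 steps, the last of which holds the remaining 8 examples.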