Best way to process terabytes of data on gcloud ml-engine with keras

送分小仙女 submitted on 2019-12-01 10:59:37

If you are willing to use tf.keras instead of standalone Keras, you can create a TFRecordDataset with the tf.data API and pass it directly to model.fit(). Bonus: you can stream directly from Google Cloud Storage, with no need to download the data first:

import tensorflow as tf

# Construct a TFRecordDataset
ds_train = tf.data.TFRecordDataset('gs://')  # path to TFRecords on GCS
# Shuffle with a 1000-record buffer, then group into batches of 32
ds_train = ds_train.shuffle(1000).batch(32)

model.fit(ds_train)
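
One caveat: a TFRecordDataset yields raw serialized records, so in practice you'll usually map a parsing function over it before handing it to model.fit(). Here is a minimal sketch, assuming TensorFlow 2.x names (tf.io.*) and a made-up schema (a float vector "x" of length 4 and an integer label "y"). The helper names and the local toy file are hypothetical and only there so the snippet runs without GCS access:

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x names (tf.io.*)

# Hypothetical schema: a float vector "x" of length 4 and an integer label "y".
FEATURE_SPEC = {
    "x": tf.io.FixedLenFeature([4], tf.float32),
    "y": tf.io.FixedLenFeature([], tf.int64),
}


def write_toy_tfrecord(path, n_examples):
    """Write n_examples dummy records so the sketch runs without GCS access."""
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(n_examples):
            feature = {
                "x": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[float(i)] * 4)),
                "y": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[i % 2])),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parse_example(serialized):
    """Decode one serialized tf.train.Example into a (features, label) pair."""
    parsed = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    return parsed["x"], parsed["y"]


def make_dataset(paths, batch_size):
    """Read TFRecord files, parse each record, then shuffle and batch."""
    ds = tf.data.TFRecordDataset(paths)
    return ds.map(parse_example).shuffle(1000).batch(batch_size)


# Local stand-in for the GCS files, for illustration only.
toy_path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
write_toy_tfrecord(toy_path, n_examples=10)
ds_toy = make_dataset([toy_path], batch_size=5)
```

With real data you would pass your own gs:// paths to make_dataset and replace FEATURE_SPEC with the schema your records were actually written with.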

To include validation data, create a TFRecordDataset with your validation TFRecords and pass that one to the validation_data argument of model.fit(). Note: this is possible as of TensorFlow 1.9.

Final note: you'll need to specify the steps_per_epoch argument. A hack I use to find the total number of examples across all TFRecord files is to simply iterate over the files and count:

import tensorflow as tf

def n_records(record_list):
    """Get the total number of records in a collection of TFRecords.
    Since a TFRecord file is intended to act as a stream of data,
    this needs to be done naively by iterating over the file and counting.
    See https://stackoverflow.com/questions/40472139

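    Args:
        record_list (list): list of GCS paths to TFRecords files

    Returns:
        int: total number of records across all files
    """
    counter = 0
    for f in record_list:
        # tf.python_io is the TensorFlow 1.x location; TensorFlow 2.x
        # exposes the same iterator as tf.compat.v1.io.tf_record_iterator.
        counter += sum(1 for _ in tf.python_io.tf_record_iterator(f))
    return counter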

You can then use it to compute steps_per_epoch:

n_train = n_records(['gs://path-to-tfrecords/record1',
                     'gs://path-to-tfrecords/record2'])

batch_size = 32  # must match the value passed to .batch()
steps_per_epoch = n_train // batch_size
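
Note that floor division ignores a trailing partial batch, so up to batch_size - 1 examples per epoch go uncounted. If you would rather count that last batch, round up instead. The same arithmetic applies to the validation_steps argument of model.fit() when the validation data is also a dataset. A small sketch with a hypothetical helper (steps_for) and made-up counts:

```python
import math


def steps_for(n_examples, batch_size, keep_partial_batch=False):
    """Number of batches needed to pass over n_examples once.

    Floor division ignores a trailing partial batch; ceiling division
    counts it as one extra step.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if keep_partial_batch:
        return math.ceil(n_examples / batch_size)
    return n_examples // batch_size


# Hypothetical counts, for illustration only.
n_train, n_val, batch_size = 1000, 250, 32

# 1000 // 32: the last 8 training examples are left out of the step count.
steps_per_epoch = steps_for(n_train, batch_size)
# ceil(250 / 32): the final, smaller validation batch is counted.
validation_steps = steps_for(n_val, batch_size, keep_partial_batch=True)
```

Which rounding you want depends on whether your input pipeline emits the final partial batch or drops it.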