Best way to process terabytes of data on gcloud ml-engine with keras

猫巷女王i  ·  2021-01-15 12:18

I want to train a model on about 2 TB of image data stored on Google Cloud Storage. I saved the image data as separate TFRecords and tried to use the TensorFlow Data API following this example.

1 Answer
  •  北海茫月  2021-01-15 12:36

    If you are willing to use tf.keras instead of standalone Keras, you can instantiate a TFRecordDataset with the tf.data API and pass it directly to model.fit(). Bonus: you get to stream directly from Google Cloud Storage, with no need to download the data first:

    # Construct a TFRecordDataset
    ds_train = tf.data.TFRecordDataset('gs://') # path to TFRecords on GCS
    ds_train = ds_train.shuffle(1000).batch(32)
    
    model.fit(ds_train)
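
    Note that a plain TFRecordDataset yields serialized Example protos, so in practice you'll want to map a parsing/decoding function over it before batching. Here is a minimal sketch, assuming each Example stores a JPEG-encoded image under an 'image' key and an integer label under a 'label' key (adjust the feature spec, image size, and batch size to your own data):

    def parse_example(serialized):
        # Assumed feature spec -- replace with the features you wrote into your TFRecords
        features = tf.parse_single_example(
            serialized,
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        image = tf.image.decode_jpeg(features['image'], channels=3)
        image = tf.image.resize_images(image, [224, 224])  # illustrative target size
        image = tf.cast(image, tf.float32) / 255.0
        label = tf.cast(features['label'], tf.int32)
        return image, label

    ds_train = tf.data.TFRecordDataset('gs://')  # path to TFRecords on GCS
    ds_train = (ds_train
                .map(parse_example, num_parallel_calls=4)
                .shuffle(1000)
                .batch(32)
                .repeat()      # repeat so the dataset doesn't run out mid-epoch
                .prefetch(1))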
    

    To include validation data, create a TFRecordDataset with your validation TFRecords and pass that one to the validation_data argument of model.fit(). Note: this is possible as of TensorFlow 1.9.
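
    For example, reusing the parse_example sketch from above (the GCS path is a placeholder):

    ds_val = tf.data.TFRecordDataset('gs://')  # path to validation TFRecords on GCS
    ds_val = ds_val.map(parse_example).batch(32).repeat()

    This ds_val is what gets passed to validation_data, together with a validation_steps count, in the final fit() sketch further down.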

    Final note: you'll need to specify the steps_per_epoch argument. A hack I use to get the total number of examples across all TFRecord files is to simply iterate over the files and count:

    import tensorflow as tf
    
    def n_records(record_list):
        """Get the total number of records in a collection of TFRecords.
        Since a TFRecord file is intended to act as a stream of data,
        this needs to be done naively by iterating over the file and counting.
        See https://stackoverflow.com/questions/40472139
    
        Args:
            record_list (list): list of GCS paths to TFRecord files
        """
        counter = 0
        for f in record_list:
            counter += sum(1 for _ in tf.python_io.tf_record_iterator(f))
        return counter
    

    You can then use this to compute steps_per_epoch:

    n_train = n_records(['gs://path-to-tfrecords/record1',
                         'gs://path-to-tfrecords/record2'])
    
    steps_per_epoch = n_train // batch_size
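
    Putting it together, a minimal sketch of the final call (the validation record path, the epochs value, and ds_val from the sketch above are illustrative assumptions):

    n_val = n_records(['gs://path-to-tfrecords/val-record1'])  # hypothetical validation file
    validation_steps = n_val // batch_size

    model.fit(ds_train,
              steps_per_epoch=steps_per_epoch,
              validation_data=ds_val,
              validation_steps=validation_steps,
              epochs=10)  # illustrative number of epochs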
    
