Obtaining total number of records from .tfrecords file in TensorFlow

南方客 2020-12-13 10:27

Is it possible to obtain the total number of records from a .tfrecords file? Related to this, how does one generally keep track of the number of epochs that h

4 Answers
  • 2020-12-13 10:40

    As tf.io.tf_record_iterator is being deprecated, the great answer of Salvador Dali should now read:

    import tensorflow as tf
    tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2.x
    sum(1 for _ in tf.data.TFRecordDataset(file_name))
    
  • 2020-12-13 10:41

    No, it is not possible without reading the file. TFRecord does not store any metadata about the data stored inside. This file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
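
    Because the file is just a sequence of length-prefixed binary strings, the records can also be counted without TensorFlow by walking the framing. The following is a minimal sketch, assuming an uncompressed file in the standard TFRecord layout (an 8-byte little-endian length, a 4-byte CRC of the length, the payload, and a 4-byte CRC of the payload). The function name count_tfrecords is hypothetical, and the CRCs are skipped rather than verified.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in an uncompressed TFRecord file by walking its framing."""
    size = os.path.getsize(path)
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)  # uint64 little-endian payload length
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            # Skip the 4-byte length CRC, the payload, and the 4-byte payload CRC.
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > size:
                raise ValueError("truncated record")
            count += 1
    return count
```

    This still reads every record header, so it is a sequential pass, but it avoids loading the payloads. It will not work on GZIP- or ZLIB-compressed files.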

    If you want, you can store this metadata manually or use a record_iterator to get the number (you will need to iterate through all the records that you have):

    sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
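
    Storing the count manually could be as simple as writing it to a small file next to the data when the records are written. This is only a sketch: the ".meta.json" suffix, the "num_records" key, and both function names are assumptions made for illustration, not a TensorFlow convention.

```python
import json


def save_record_count(tfrecord_path, count):
    """Write the record count to a sidecar file next to the TFRecord file."""
    with open(tfrecord_path + ".meta.json", "w") as f:
        json.dump({"num_records": count}, f)


def load_record_count(tfrecord_path):
    """Read back the record count saved by save_record_count."""
    with open(tfrecord_path + ".meta.json") as f:
        return json.load(f)["num_records"]
```

    The count would be incremented in the loop that writes the records and saved once at the end, so later readers never have to scan the file.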
    

    If you want to know the current epoch, you can do this either from TensorBoard or by printing the epoch number from the training loop.
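
    Once the record count is known, the current epoch can also be derived from the global step. A minimal sketch, assuming a fixed batch size and that the final partial batch of each epoch is kept rather than dropped (both function names are hypothetical):

```python
import math


def steps_per_epoch(num_records, batch_size):
    """Number of batches in one pass over the data, keeping the last partial batch."""
    return math.ceil(num_records / batch_size)


def epoch_of_step(global_step, num_records, batch_size):
    """Zero-based epoch that a zero-based global step falls into."""
    return global_step // steps_per_epoch(num_records, batch_size)
```

    For example, with 1000 records and a batch size of 32 there are 32 steps per epoch, so step 32 is the first step of epoch 1.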

  • 2020-12-13 10:46

    As per the deprecation warning on tf_record_iterator, we can also use eager execution to count records.

    #!/usr/bin/env python
    from __future__ import print_function
    
    import tensorflow as tf
    import sys
    
    assert len(sys.argv) == 2, \
        "USAGE: {} <file_glob>".format(sys.argv[0])
    
    tf.enable_eager_execution()
    
    input_pattern = sys.argv[1]
    
    # Expand glob if there is one
    input_files = tf.io.gfile.glob(input_pattern)
    
    # Create the dataset
    data_set = tf.data.TFRecordDataset(input_files)
    
    # Count the records
    records_n = sum(1 for record in data_set)
    
    print("records_n = {}".format(records_n))
    
  • 2020-12-13 11:03

    To count the number of records, you should be able to use tf.python_io.tf_record_iterator.

    c = 0
    for fn in tf_records_filenames:
        for record in tf.python_io.tf_record_iterator(fn):
            c += 1
    

    To just keep track of the model training, TensorBoard comes in handy.
