Is it possible to obtain the total number of records from a .tfrecords file? Related to this, how does one generally keep track of the number of epochs that have elapsed while training models? While it is possible for us to specify the batch_size and num_of_epochs, I am not sure if it is straightforward to obtain values such as the current epoch or the number of batches per epoch, just so that I could have more control over how the training is progressing. Currently, I'm using a dirty hack to compute these, since I know beforehand how many records there are in my .tfrecords file and the size of my minibatches. Any help is appreciated.
To count the number of records, you should be able to use tf.python_io.tf_record_iterator:
import tensorflow as tf

# tf_records_filenames: your list of .tfrecords paths
c = 0
for fn in tf_records_filenames:
    for record in tf.python_io.tf_record_iterator(fn):
        c += 1
To keep track of model training, TensorBoard comes in handy.
No, it is not possible to get this number without reading the file. A TFRecord file does not store any metadata about the data inside it; it is just a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
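To make the "not random access" point concrete: in an uncompressed TFRecord file each record is stored as an 8-byte little-endian length, a 4-byte masked CRC32C of that length, the payload, and a 4-byte masked CRC32C of the payload. Counting therefore still means walking every record, although only the length headers need to be read. The sketch below is a stdlib-only illustration under stated assumptions: it handles uncompressed files only (not GZIP/ZLIB TFRecords), it skips rather than verifies the CRCs, and the name count_tfrecords_raw is hypothetical.

```python
import os
import struct


def count_tfrecords_raw(path):
    """Count records in an UNCOMPRESSED TFRecord file using only the stdlib.

    Assumed on-disk layout of each record:
        uint64 length (little-endian)
        uint32 masked CRC32C of the length
        byte   data[length]
        uint32 masked CRC32C of the data
    Only the length is read; both CRCs are skipped, not verified.
    """
    file_size = os.path.getsize(path)
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated length header in {}".format(path))
            (length,) = struct.unpack("<Q", header)
            # Jump over the length CRC, the payload and the data CRC.
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > file_size:
                raise ValueError("truncated record in {}".format(path))
            count += 1
    return count
```

For everyday use, the TensorFlow-based iterators in this thread are the safer choice, since they go through TensorFlow's own record reader; this sketch mainly shows why a full scan is unavoidable.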
If you want, you can store this metadata manually, or use a record iterator to get the number (you will need to iterate through all the records you have):
sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
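If you would rather not rescan the files on every run, one way to store this metadata manually is a small sidecar file next to each .tfrecords file. This is a sketch under assumptions: the .count.json suffix and all helper names are hypothetical conventions that TensorFlow knows nothing about, and num_records is expected to come from counting records as you write them (for example, incrementing a counter after each tf.io.TFRecordWriter.write call).

```python
import json
import os

SIDECAR_SUFFIX = ".count.json"  # hypothetical naming convention


def write_count_sidecar(tfrecord_path, num_records):
    """Store the record count in a JSON file next to the TFRecord file."""
    sidecar = tfrecord_path + SIDECAR_SUFFIX
    with open(sidecar, "w") as f:
        json.dump({"file": os.path.basename(tfrecord_path),
                   "num_records": num_records}, f)
    return sidecar


def read_count_sidecar(tfrecord_path):
    """Return the stored count, or None if no sidecar exists."""
    sidecar = tfrecord_path + SIDECAR_SUFFIX
    if not os.path.exists(sidecar):
        return None
    with open(sidecar) as f:
        return json.load(f)["num_records"]


def total_records(tfrecord_paths):
    """Sum the stored counts; fail loudly if any file has no sidecar."""
    total = 0
    for path in tfrecord_paths:
        n = read_count_sidecar(path)
        if n is None:
            raise FileNotFoundError("no count sidecar for {}".format(path))
        total += n
    return total
```

The trade-off is that the sidecar can go stale if the TFRecord file is rewritten without updating it, so it is best written by the same code that writes the records.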
If you want to know the current epoch, you can get it either from TensorBoard or by printing the epoch counter from your training loop.
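Once the total record count is known (for example, from one of the counting approaches in this thread), the values asked about in the question are plain arithmetic. A minimal sketch, assuming one training step per minibatch and that the input pipeline does not filter out any records; the helper names are hypothetical:

```python
def batches_per_epoch(num_records, batch_size, drop_remainder=False):
    """Number of minibatches needed to see every record once."""
    if drop_remainder:
        return num_records // batch_size  # last partial batch is discarded
    # Integer ceiling division: the last batch may be smaller.
    return (num_records + batch_size - 1) // batch_size


def epoch_and_batch(global_step, steps_per_epoch):
    """Map a 0-based global step to (0-based epoch, 0-based batch in epoch)."""
    return divmod(global_step, steps_per_epoch)
```

For example, with 1050 records and a batch size of 100 there are 11 batches per epoch (10 if the last partial batch is dropped), and global step 25 corresponds to epoch 2, batch 3 (both 0-based). If you instead write a nested loop, with an outer for epoch in range(num_epochs) over a dataset that is not repeated, the epoch number is simply the outer loop variable.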
Following the deprecation warning on tf_record_iterator, we can also count records with tf.data under eager execution (the default in TensorFlow 2.x):
#!/usr/bin/env python
from __future__ import print_function
import tensorflow as tf
import sys
assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])
input_pattern = sys.argv[1]
# Expand glob if there is one
input_files = tf.io.gfile.glob(input_pattern)
# Create the dataset
data_set = tf.data.TFRecordDataset(input_files)
# Count the records
records_n = sum(1 for record in data_set)
print("records_n = {}".format(records_n))
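The counting line in this script pulls every record into Python one at a time. A variant, assuming TensorFlow 2.x with eager execution, is to let tf.data do the fold itself with Dataset.reduce. The function name count_records is hypothetical; the result is converted back to a plain Python int.

```python
import numpy as np
import tensorflow as tf


def count_records(file_pattern):
    """Count records in all TFRecord files matching file_pattern (TF 2.x)."""
    files = tf.io.gfile.glob(file_pattern)
    dataset = tf.data.TFRecordDataset(files)
    # reduce() folds over the dataset inside TensorFlow: start at 0 (int64)
    # and add 1 for every record, ignoring the record contents.
    count = dataset.reduce(np.int64(0), lambda acc, _: acc + 1)
    return int(count.numpy())
```

Usage would look like print(count_records("data/train-*.tfrecords")), where the path is only a placeholder. Like the other approaches here, this still reads every record once.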
As tf.io.tf_record_iterator is being deprecated, the great answer by Salvador Dali should now read:
sum(1 for _ in tf.data.TFRecordDataset(file_name))