Question
Is it possible to obtain the total number of records from a .tfrecords file? Related to this, how does one generally keep track of the number of epochs that have elapsed while training models? While it is possible for us to specify the batch_size and num_of_epochs, I am not sure if it is straightforward to obtain values such as the current epoch, the number of batches per epoch, etc., just so that I could have more control over how the training is progressing. Currently, I'm just using a dirty hack to compute this, as I know beforehand how many records there are in my .tfrecords file and the size of my minibatches. Appreciate any help.
Answer 1:
To count the number of records, you should be able to use tf.python_io.tf_record_iterator.
c = 0
# tf_records_filenames: list of paths to your .tfrecords files
for fn in tf_records_filenames:
    for record in tf.python_io.tf_record_iterator(fn):
        c += 1
To just keep track of the model training, TensorBoard comes in handy.
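A minimal sketch of that, assuming the same TF 1.x graph-mode API as the snippet above (the log directory and the fake loss values are placeholders):
import tensorflow as tf
loss = tf.placeholder(tf.float32, name="loss")   # stand-in for a real training loss
tf.summary.scalar("loss", loss)                  # scalar curve shown in TensorBoard
merged = tf.summary.merge_all()
with tf.Session() as sess:
    writer = tf.summary.FileWriter("/tmp/tfrecord_demo_logs", sess.graph)
    for step in range(100):                      # stand-in for the training loop
        summary = sess.run(merged, feed_dict={loss: 1.0 / (step + 1)})
        writer.add_summary(summary, global_step=step)
    writer.close()
Then point TensorBoard at the same directory: tensorboard --logdir /tmp/tfrecord_demo_logs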
Answer 2:
No, it is not possible. TFRecord does not store any metadata about the data being stored inside. This file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
If you want, you can store this metadata manually or use a record_iterator to get the number (you will need to iterate through all the records that you have):
sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
If you want to know the current epoch, you can do this either from TensorBoard or by printing the number from the loop.
Answer 3:
As per the deprecation warning on tf_record_iterator, we can also use eager execution to count records.
#!/usr/bin/env python
from __future__ import print_function
import tensorflow as tf
import sys
assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])
tf.enable_eager_execution()
input_pattern = sys.argv[1]
# Expand glob if there is one
input_files = tf.io.gfile.glob(input_pattern)
# Create the dataset
data_set = tf.data.TFRecordDataset(input_files)
# Count the records
records_n = sum(1 for record in data_set)
print("records_n = {}".format(records_n))
Answer 4:
As tf.io.tf_record_iterator is being deprecated, the great answer of Salvador Dali should now read:
tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))
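In TF 2.x, eager execution is on by default, so the enable_eager_execution call can be dropped. A graph-friendly alternative is Dataset.reduce; a sketch assuming TF 2.x, with file_name standing in for the path to your .tfrecords file:
import tensorflow as tf
file_name = "data.tfrecords"                     # placeholder path
dataset = tf.data.TFRecordDataset(file_name)
# Fold over the dataset, adding 1 per record; returns a scalar int64 tensor.
records_n = dataset.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1)
print("records_n = {}".format(int(records_n)))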
Source: https://stackoverflow.com/questions/40472139/obtaining-total-number-of-records-from-tfrecords-file-in-tensorflow