Obtaining total number of records from .tfrecords file in Tensorflow

Submitted by ぃ、小莉子 on 2019-11-30 06:27:26

Question


Is it possible to obtain the total number of records from a .tfrecords file? Related to this, how does one generally keep track of the number of epochs that have elapsed while training a model? We can specify batch_size and num_of_epochs, but I am not sure whether it is straightforward to obtain values such as the current epoch or the number of batches per epoch, which would give me more control over how training is progressing. Currently I'm using a dirty hack to compute this, because I know beforehand how many records there are in my .tfrecords file and the size of my minibatches. Any help is appreciated.


Answer 1:


To count the number of records, you should be able to use tf.python_io.tf_record_iterator.

import tensorflow as tf

# tf_records_filenames: a list of paths to your .tfrecords files
c = 0
for fn in tf_records_filenames:
    for record in tf.python_io.tf_record_iterator(fn):
        c += 1

If you just want to keep track of how training is going, TensorBoard comes in handy.




Answer 2:


No, it is not possible. TFRecord does not store any metadata about the data being stored inside. This file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.
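Because the file is just a sequence of length-prefixed binary strings, the count can also be obtained without TensorFlow by walking the framing. The sketch below assumes an uncompressed file in the standard TFRecord layout (8-byte little-endian length, 4-byte CRC of the length, the payload, 4-byte CRC of the payload). It does not verify the checksums, and the function name count_tfrecords is only illustrative.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in an uncompressed TFRecord file by walking its framing.

    Assumption: standard layout, no GZIP/ZLIB compression. CRCs are skipped,
    not verified, so corrupted payloads are not detected.
    """
    count = 0
    size = os.path.getsize(path)
    with open(path, "rb") as f:
        while True:
            header = f.read(8)  # uint64 little-endian payload length
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated length header in %s" % path)
            (length,) = struct.unpack("<Q", header)
            # Skip the 4-byte length CRC, the payload, and the 4-byte payload CRC.
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > size:
                raise ValueError("truncated record in %s" % path)
            count += 1
    return count
```

This still reads the file sequentially, but it only touches the 8-byte headers and seeks over the payloads, so it avoids handing every record to Python.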

If you want, you can store this metadata manually, or use a record_iterator to get the number (you will need to iterate through all the records that you have):

sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
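One way to store the metadata manually is a small sidecar file written next to the .tfrecords file when it is created. This is only a sketch: the ".meta.json" suffix, the "num_records" key and both function names are assumptions made for illustration, not a TensorFlow convention.

```python
import json


def write_record_count(tfrecords_path, count):
    """Save the record count in a JSON sidecar next to the TFRecord file."""
    meta_path = tfrecords_path + ".meta.json"  # hypothetical naming scheme
    with open(meta_path, "w") as f:
        json.dump({"num_records": count}, f)
    return meta_path


def read_record_count(tfrecords_path):
    """Read the count back without touching the TFRecord file itself."""
    with open(tfrecords_path + ".meta.json") as f:
        return json.load(f)["num_records"]
```

You would call write_record_count once, right after writing the records, and read_record_count at training time, so the full file never has to be scanned just to get its size.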

If you want to know the current epoch, you can either read it from TensorBoard or print the counter from your training loop.
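Once the record count is known, the number of batches per epoch and the current epoch follow from simple arithmetic on the step counter. A minimal sketch, assuming one batch per training step and a fixed batch size; the function names and the drop_remainder flag are illustrative:

```python
import math


def steps_per_epoch(num_records, batch_size, drop_remainder=False):
    """Number of batches needed to see every record once."""
    if drop_remainder:
        return num_records // batch_size  # the final partial batch is discarded
    return math.ceil(num_records / batch_size)  # the final partial batch counts


def epoch_progress(global_step, num_records, batch_size):
    """Return (epoch, batch_in_epoch), both zero-based, for a given step."""
    return divmod(global_step, steps_per_epoch(num_records, batch_size))


# 1000 records in batches of 32: 31 full batches plus one partial batch.
print(steps_per_epoch(1000, 32))     # 32
print(epoch_progress(70, 1000, 32))  # (2, 6)
```

Printing epoch_progress(step, ...) inside the training loop gives the "current epoch" figure asked about in the question.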




Answer 3:


As per the deprecation warning on tf_record_iterator, we can also use eager execution to count records.

#!/usr/bin/env python
from __future__ import print_function

import tensorflow as tf
import sys

assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])

tf.enable_eager_execution()

input_pattern = sys.argv[1]

# Expand glob if there is one
input_files = tf.io.gfile.glob(input_pattern)

# Create the dataset
data_set = tf.data.TFRecordDataset(input_files)

# Count the records
records_n = sum(1 for record in data_set)

print("records_n = {}".format(records_n))



Answer 4:


As tf.io.tf_record_iterator is being deprecated, the great answer of Salvador Dali should now read:

tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))


Source: https://stackoverflow.com/questions/40472139/obtaining-total-number-of-records-from-tfrecords-file-in-tensorflow
