In TensorFlow 2.0, how can I see the number of elements in a dataset?

Submitted by 邮差的信 on 2020-06-16 13:03:21

Question


When I load a dataset, I wonder if there is any quick way to find the number of samples or batches in that dataset. I know that if I load a dataset with with_info=True, I can see for example total_num_examples=6000, but this information is not available if I split a dataset.

Currently, I count the number of samples as follows, but I'm wondering whether there is a better solution:

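import tensorflow_datasets as tfds

# Split the CIFAR-10 training set into three parts (subsplit API from older TFDS releases)
train_subsplit_1, train_subsplit_2, train_subsplit_3 = tfds.Split.TRAIN.subsplit(3)
cifar10_trainsub3 = tfds.load("cifar10", split=train_subsplit_3)
cifar10_trainsub3 = cifar10_trainsub3.batch(1000)

# Count samples by iterating over the batches; i ends as the index of the last batch
n = 0
for i, batch in enumerate(cifar10_trainsub3.take(-1)):
    print(i, n, batch['image'].shape)
    n += len(batch['image'])

print(i, n)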
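A possible in-graph alternative to the Python loop, sketched under the assumption that each batch is a dict with an 'image' key as in the question, sums the leading (batch) dimension of every batch with Dataset.reduce. The helper name count_samples and the synthetic dataset below are hypothetical, for illustration only:

```python
import tensorflow as tf

def count_samples(batched_ds, key="image"):
    """Hypothetical helper: total number of samples in a batched dataset of dicts."""
    # Start from an int64 zero and add the size of each batch's first dimension.
    return batched_ds.reduce(
        tf.constant(0, dtype=tf.int64),
        lambda total, batch: total + tf.cast(tf.shape(batch[key])[0], tf.int64),
    )

# Synthetic stand-in for the CIFAR-10 subsplit: 2500 tiny "images" in batches of 1000.
ds = tf.data.Dataset.from_tensor_slices({"image": tf.zeros([2500, 2, 2, 3])}).batch(1000)
print(int(count_samples(ds)))  # 2500 (batches of 1000, 1000 and 500)
```

Unlike the enumerate loop, this counts samples rather than batches, but it still has to read through the whole dataset once.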

Answer 1:


If it's possible to know the length then you could use:

tf.data.experimental.cardinality(dataset)
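As a quick illustration of what cardinality reports, here is a sketch on small synthetic datasets (assumed for illustration, not taken from the question). A known size comes back as an int64 scalar tensor, while transformations such as filter make the size unknown and repeat() makes it infinite; both cases are signalled with special sentinel constants in tf.data.experimental:

```python
import tensorflow as tf

# Known size: 10 elements batched by 3 gives 4 batches (the last one is partial).
known = tf.data.Dataset.range(10).batch(3)
print(int(tf.data.experimental.cardinality(known)))  # 4

# filter() cannot know in advance how many elements survive, so the size is unknown.
filtered = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)
print(int(tf.data.experimental.cardinality(filtered))
      == tf.data.experimental.UNKNOWN_CARDINALITY)  # True

# repeat() with no count produces an endless stream.
endless = tf.data.Dataset.range(10).repeat()
print(int(tf.data.experimental.cardinality(endless))
      == tf.data.experimental.INFINITE_CARDINALITY)  # True
```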

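but the problem is that a TF dataset is inherently lazily evaluated, so we might not know its size up front; in that case cardinality returns tf.data.experimental.UNKNOWN_CARDINALITY (-2). Indeed, it's perfectly possible for a dataset to represent an infinite set of data, which is reported as tf.data.experimental.INFINITE_CARDINALITY (-1)!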

If the dataset is small enough, you could also just iterate over it to get the length. I've used the following ugly little construct before, but it depends on the dataset being small enough that we're happy to load it into memory, and it's really not an improvement over your for loop above!

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1
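Combining the two approaches, here is a sketch of a hypothetical helper, dataset_length, that trusts cardinality when it is known, refuses infinite datasets, and otherwise counts elements by iterating once. Counting with sum(1 for _ in ds) keeps only a running total instead of a list of indices, and it returns 0 for an empty dataset, where the list-indexing construct above would raise an IndexError:

```python
import tensorflow as tf

def dataset_length(ds):
    """Hypothetical helper: number of elements in a finite tf.data.Dataset."""
    n = int(tf.data.experimental.cardinality(ds))
    if n == tf.data.experimental.INFINITE_CARDINALITY:
        raise ValueError("dataset is infinite; its length is undefined")
    if n == tf.data.experimental.UNKNOWN_CARDINALITY:
        # Size not known statically: fall back to one full pass, counting elements.
        n = sum(1 for _ in ds)
    return n

evens = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)
print(dataset_length(evens))  # 5
```

Note that an unknown cardinality does not by itself guarantee the dataset is finite (for example, a generator-based dataset whose generator never stops), in which case the fallback pass would never terminate.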


Source: https://stackoverflow.com/questions/56369900/in-tensorflow-2-0-how-can-i-see-the-number-of-elements-in-a-dataset
