Question
When I load a dataset, I wonder if there is any quick way to find the number of samples or batches in that dataset. I know that if I load a dataset with with_info=True, I can see, for example, total_num_examples=6000, but this information is not available if I split the dataset.
Currently I count the number of samples as follows, but I am wondering if there is a better solution:
import tensorflow_datasets as tfds

train_subsplit_1, train_subsplit_2, train_subsplit_3 = tfds.Split.TRAIN.subsplit(3)
cifar10_trainsub3 = tfds.load("cifar10", split=train_subsplit_3)
cifar10_trainsub3 = cifar10_trainsub3.batch(1000)

n = 0
for i, batch in enumerate(cifar10_trainsub3.take(-1)):
    print(i, n, batch['image'].shape)
    n += len(batch['image'])
print(i, n)
Answer 1:
If it's possible to know the length, then you could use:
tf.data.experimental.cardinality(dataset)
The problem is that a TF dataset is inherently lazily evaluated, so we might not know its size up front. Indeed, it's perfectly possible for a dataset to represent an infinite stream of data!
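A small sketch of how cardinality behaves, using toy in-memory datasets (the range dataset here is just an illustration, not the CIFAR-10 split from the question). Transformations like filter make the length statically unknowable, and repeat makes it infinite; TF signals these with the sentinel constants UNKNOWN_CARDINALITY and INFINITE_CARDINALITY:

```python
import tensorflow as tf

# A finite in-memory dataset: its cardinality is known statically.
finite = tf.data.Dataset.range(100)
print(int(tf.data.experimental.cardinality(finite)))  # 100

# Filtering makes the length unknowable without actually iterating.
filtered = finite.filter(lambda x: x % 2 == 0)
unknown = tf.data.experimental.cardinality(filtered)
print(unknown == tf.data.experimental.UNKNOWN_CARDINALITY)  # True

# Repeating without a count yields an infinite dataset.
infinite = finite.repeat()
print(tf.data.experimental.cardinality(infinite)
      == tf.data.experimental.INFINITE_CARDINALITY)  # True
```

So cardinality is only a quick win when the pipeline's length is statically known; otherwise it returns a sentinel rather than iterating for you.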
If it is a small enough dataset, you could also just iterate over it to get the length. I've used the following ugly little construct before, but it depends on the dataset being small enough to happily load into memory, and it's really not an improvement over your for loop above!
dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1
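Two equivalent counts that avoid materializing the enumerate list, sketched on a toy range dataset (the 6000-element size is just a stand-in for the split in the question, not real CIFAR-10 data):

```python
import tensorflow as tf

# Toy stand-in for the dataset split in the question.
dataset = tf.data.Dataset.range(6000)

# Generator-based count: iterates once, keeps no intermediate list.
dataset_length = sum(1 for _ in dataset)
print(dataset_length)  # 6000

# The same count expressed inside the tf.data pipeline via reduce.
dataset_length = int(dataset.reduce(0, lambda count, _: count + 1))
print(dataset_length)  # 6000
```

Both still pay the cost of a full pass over the data; the reduce variant simply keeps the counting inside the tf.data graph instead of in Python.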
Source: https://stackoverflow.com/questions/56369900/in-tensorflow-2-0-how-can-i-see-the-number-of-elements-in-a-dataset