Question
I have a dataset with 11 samples. When I set BATCH_SIZE to 2, the following code produces errors:
import tensorflow as tf

dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
The problem lies in dataset = dataset.batch(batch_size): when the Dataset reaches the last batch, only 1 sample remains. Is there any way to randomly pick one of the previously visited samples and use it to fill out the last batch?
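To make the requested behavior concrete, here is a minimal plain-Python sketch of the logic, not a tf.data API. It assumes the samples fit in an in-memory list, and the helper name batch_with_random_padding is hypothetical. The short final batch is topped up with samples drawn at random, without replacement, from those already emitted in earlier batches.

```python
import random


def batch_with_random_padding(samples, batch_size, seed=None):
    """Hypothetical helper: split samples into batches of exactly batch_size.

    If the final batch is short, fill it with samples drawn at random
    (without replacement) from those already emitted in earlier batches.
    """
    samples = list(samples)
    if len(samples) < batch_size:
        raise ValueError("need at least batch_size samples to pad from")
    rng = random.Random(seed)
    batches = []
    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]
        missing = batch_size - len(batch)
        if missing:
            # Only the last batch can be short, so start >= batch_size here
            # and samples[:start] holds enough previously visited samples.
            batch += rng.sample(samples[:start], missing)
        batches.append(batch)
    return batches


batches = batch_with_random_padding(range(11), batch_size=2, seed=0)
print(len(batches))    # 6
print(batches[-1][0])  # 10
```

With 11 samples and a batch size of 2, the sixth batch holds sample 10 plus one randomly chosen sample from 0 to 9.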
Answer 1:
@mining proposes a solution by padding the filenames.
Another solution is to use tf.contrib.data.batch_and_drop_remainder, which batches the data with a fixed batch size and drops the final batch if it is smaller.
In your example, with 11 inputs and a batch size of 2, this yields 5 batches of 2 elements; the 11th element is dropped.
Here is the example from the documentation:
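import tensorflow as tf  # TF 1.x only; tf.contrib does not exist in TF 2.x
dataset = tf.data.Dataset.range(11)
# Batches of exactly 2 elements; the trailing 1-element batch is discarded.
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))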
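The arithmetic behind "5 batches of 2" generalizes to any dataset size: with n elements and batch size b, dropping the remainder keeps n // b full batches and discards n % b elements. The small plain-Python helper below (hypothetical name drop_remainder_counts, not part of TensorFlow) simply computes those two numbers.

```python
def drop_remainder_counts(num_elements, batch_size):
    """Hypothetical helper: (full_batches, dropped_elements) when a short final batch is discarded."""
    return divmod(num_elements, batch_size)


print(drop_remainder_counts(11, 2))  # (5, 1): 5 batches of 2, element 10 dropped
print(drop_remainder_counts(11, 4))  # (2, 3)
```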
Answer 2:
You can just set drop_remainder=True in your call to batch.
dataset = dataset.batch(batch_size, drop_remainder=True)
From the documentation:
drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
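To contrast the default with drop_remainder=True, the sketch below models the batch sizes on a plain Python list. It is an illustration under the assumption that batching simply slices consecutive elements; batch_model is a hypothetical name, not the tf.data implementation.

```python
def batch_model(elements, batch_size, drop_remainder=False):
    """Hypothetical model of Dataset.batch on a plain list of elements."""
    elements = list(elements)
    batches = [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the short final batch
    return batches


for flag in (False, True):
    sizes = [len(b) for b in batch_model(range(11), 2, drop_remainder=flag)]
    print(flag, sizes)
# False [2, 2, 2, 2, 2, 1]
# True [2, 2, 2, 2, 2]
```

The default run ends with the 1-element batch from the question; setting the flag removes it.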
Source: https://stackoverflow.com/questions/48325636/how-to-pad-to-fixed-batch-size-in-tf-data-dataset