How to pad to fixed BATCH_SIZE in tf.data.Dataset?

心已入冬 submitted on 2019-12-23 13:07:05

Question


I have a dataset with 11 samples. When I set BATCH_SIZE to 2, the following code runs into errors:

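import tensorflow as tf

# filenames, parser, shuffle and batch_size are defined elsewhere
# TF 1.3-era API; in TF 1.4+ this is tf.data.TFRecordDataset
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
# with 11 records and batch_size = 2, the final batch holds only 1 record
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)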

The problem lies in dataset = dataset.batch(batch_size): when the Dataset reaches the last batch, only 1 sample remains. Is there a way to randomly pick one of the previously visited samples to complete the last batch?
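The padding the question asks for can be sketched in plain Python (no TensorFlow) so the logic is easy to check: split the samples into batches of batch_size and, if the last batch comes up short, top it up with samples drawn at random from those already emitted. The function name pad_last_batch and the use of random.Random are illustrative assumptions, not part of tf.data.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def pad_last_batch(samples: Sequence[T], batch_size: int,
                   rng: Optional[random.Random] = None) -> List[List[T]]:
    """Hypothetical helper: split `samples` into batches of exactly
    `batch_size`, topping up a short final batch with random picks from
    samples that were already emitted in earlier batches."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if not samples:
        return []
    if rng is None:
        rng = random.Random()
    batches = [list(samples[i:i + batch_size])
               for i in range(0, len(samples), batch_size)]
    shortfall = batch_size - len(batches[-1])
    if shortfall > 0:
        # Pool of previously visited samples; if there is only one (short)
        # batch, fall back to the samples in that batch (an assumption).
        seen = [x for batch in batches[:-1] for x in batch] or list(samples)
        batches[-1].extend(rng.choice(seen) for _ in range(shortfall))
    return batches


batches = pad_last_batch(list(range(11)), 2, random.Random(0))
print(len(batches))                          # 6
print(all(len(b) == 2 for b in batches))     # True
print(batches[-1][0])                        # 10 (the leftover sample)
```

Within tf.data itself, a similar effect can be had by calling dataset.repeat() before dataset.batch(batch_size) and then limiting the stream with dataset.take(6), since ceil(11 / 2) = 6. The extra element in the last batch then comes from the start of the next epoch, which is reshuffled by default when shuffle() is applied before repeat().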


Answer 1:


@mining proposes a solution by padding the filenames.

Another solution is to use tf.contrib.data.batch_and_drop_remainder. This will batch the data with a fixed batch size and drop the last smaller batch.

In your example, with 11 inputs and a batch size of 2, this would yield 5 batches of 2 elements.

Here is the example from the documentation:

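import tensorflow as tf  # TF 1.x only; tf.contrib was removed in TF 2.0
dataset = tf.data.Dataset.range(11)
# yields [0, 1], [2, 3], [4, 5], [6, 7], [8, 9]; element 10 is dropped
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))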
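To make the arithmetic concrete, here is a framework-free Python model of drop-remainder batching applied to range(11) with a batch size of 2. The function batch_drop_remainder is a hypothetical stand-in, not the TensorFlow implementation; it only mirrors the documented behavior of keeping full batches and discarding the short final one.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def batch_drop_remainder(items: Sequence[T], batch_size: int) -> List[List[T]]:
    """Hypothetical model: group `items` into batches of exactly `batch_size`,
    discarding any trailing elements that cannot fill a whole batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full = len(items) // batch_size  # number of complete batches
    return [list(items[i * batch_size:(i + 1) * batch_size]) for i in range(full)]


batches = batch_drop_remainder(range(11), 2)
print(len(batches))   # 5
print(batches[-1])    # [8, 9]  (element 10 is dropped)
```

The cost is that len(items) % batch_size samples (here one) are skipped per epoch; if the data is shuffled before batching, which sample is skipped changes from epoch to epoch.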



Answer 2:


You can just set drop_remainder=True in your call to batch.

dataset = dataset.batch(batch_size, drop_remainder=True)

From the documentation:

drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
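The difference between the default and drop_remainder=True comes down to rounding: the default yields ceil(n / batch_size) batches per epoch, the last of which may be short, while drop_remainder=True yields floor(n / batch_size) batches, all full. A small helper (hypothetical, for illustration only) makes this explicit for the 11-sample case:

```python
def batches_per_epoch(num_samples: int, batch_size: int,
                      drop_remainder: bool = False) -> int:
    """Hypothetical helper: number of batches one pass over `num_samples`
    produces with Dataset.batch-style batching."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        return num_samples // batch_size                # only full batches
    return (num_samples + batch_size - 1) // batch_size  # last batch may be short


print(batches_per_epoch(11, 2))                       # 6 (last batch has 1 element)
print(batches_per_epoch(11, 2, drop_remainder=True))  # 5
```

Setting drop_remainder=True also lets tf.data report a static batch dimension of batch_size rather than None in the dataset's output shapes, which is usually why a fixed batch size matters downstream.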



Source: https://stackoverflow.com/questions/48325636/how-to-pad-to-fixed-batch-size-in-tf-data-dataset
