Question
I have a question about the use of batch, repeat, and shuffle with tf.data.Dataset.
It is not clear to me exactly how repeat and shuffle are used. I understand that .batch dictates how many training examples go into each stochastic gradient descent step, but the uses of .repeat and .shuffle are still not clear to me.
First Question
Even after reading other explanations, my understanding is only that .repeat is used to reiterate over the dataset at the point where a tf.errors.OutOfRangeError would otherwise be thrown. Does that mean that in my code I no longer have to implement:
    try:
        while True:
            _ = sess.run(self.optimizer)
    except tf.errors.OutOfRangeError:
        pass
because .repeat will automatically repeat the dataset once it is exhausted? When does it stop? Or will it never stop, so that I have to exit the while True loop myself once a certain number of batches (say 1000) have passed?
Second Question
Secondly, the use of .shuffle makes no sense to me. Does .shuffle().batch() imply that, if I have, say, 100,000 samples, 1000 of them are placed randomly in a buffer by .shuffle, and then, say, 100 of them are batched by .batch()? From my understanding, the next batch will use 999 of those samples and place 1 new one in the buffer. So if my samples have no order to them, should .shuffle be avoided altogether? And if .batch is used, would it still batch 100 from those 999+1 in the buffer?
Third Question
And lastly, if I am using a separate tf.data.Dataset object for testing, what order of .shuffle() and .batch() should I consider? Right now I use:
    sess.run(self.test_init)
    try:
        while True:
            accuracy_batch = sess.run(self.accuracy)
    except tf.errors.OutOfRangeError:
        pass
With:
    test_data = self.test_dataset.shuffle(self.batch_size).batch(self.batch_size)
I have over 110,000 training examples at my disposal, so self.batch_size sets the number of samples I want to use to test my accuracy. So, if I wanted to just test on the whole test dataset, I wouldn't use .batch? But since I have it iterating over the whole dataset with while True, does it make no difference? With .shuffle I noticed my accuracies changed, but without it they were very similar. This makes me think .shuffle is randomizing the batch and may be reusing training examples?
Answer 1:
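First Question
That's correct, with one caveat: a dataset that repeats indefinitely is never exhausted, so in that case you no longer need to catch the OutOfRangeError.
repeat() takes an optional argument for the number of times it should repeat. This means repeat(10) will iterate over the entire dataset 10 times, after which the dataset is exhausted and OutOfRangeError is raised as usual. If you omit the argument, it will repeat indefinitely, and it is up to you to stop the loop (for example, after a fixed number of batches).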
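The two forms of repeat() can be compared in a small sketch. It assumes TensorFlow 2.x with eager execution (the question uses TF 1.x sessions, where the same dataset methods apply but iteration goes through an iterator and sess.run); the toy dataset of five integers is purely illustrative.

```python
import tensorflow as tf

# Toy dataset (illustrative only): 5 records, batches of 2 -> 3 batches per epoch.
ds = tf.data.Dataset.range(5).batch(2)

# Finite repeat: 3 passes over the data, then the dataset is exhausted.
finite = ds.repeat(3)
num_batches = sum(1 for _ in finite)  # 3 epochs * 3 batches per epoch

# Infinite repeat: never exhausted, so the caller bounds the loop, here with take().
infinite = ds.repeat()
first_four = [b.numpy().tolist() for b in infinite.take(4)]

print(num_batches)
print(first_four)
```

With repeat() and no argument, a fixed step budget (take(n) here, or a loop over range(n) in session-style code) takes the place of the try/except around OutOfRangeError.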
Second Question
shuffle() (if used) should be called before batch(): we want to shuffle records, not batches.
The buffer is first filled by adding your records in order. Then, once it is full, a random record is selected from it and emitted, and a new record is read from the original source to take its place.
If you have something like
    ds.shuffle(1000).batch(100)
then in order to return a single batch, this last step is repeated 100 times (maintaining the buffer at 1000). Batching is a separate operation.
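The buffer mechanics described above can be modelled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's actual implementation; shuffle_buffer and batch are hypothetical helper names, and the sizes (20 records, a buffer of 5, batches of 4) are chosen only to keep the example small.

```python
import random


def shuffle_buffer(records, buffer_size, seed=None):
    """Yield records in an order a fixed-size shuffle buffer could emit them."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    source = iter(records)
    buffer = []
    # Fill the buffer with the first `buffer_size` records, in order.
    for record in source:
        buffer.append(record)
        if len(buffer) == buffer_size:
            break
    # Steady state: emit a random buffered record, refill its slot from the source.
    for record in source:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = record
    # Source exhausted: drain the remaining records in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch(records, batch_size):
    """Group consecutive records into lists of `batch_size` (last may be shorter)."""
    chunk = []
    for record in records:
        chunk.append(record)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


# Rough analogue of ds.shuffle(5).batch(4) on 20 records.
batches = list(batch(shuffle_buffer(range(20), buffer_size=5, seed=0), batch_size=4))
flat = [x for b in batches for x in b]
print(batches)
```

In this model each record is emitted exactly once per pass, so shuffling does not reuse examples. A record can only be drawn after it has entered the buffer, which is why a buffer much smaller than the dataset gives only a local shuffle.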
Third Question
Generally we don't shuffle a test set at all - only the training set (We evaluate using the entire test set anyway, right? So why shuffle?).
Regarding "So, if I wanted to just test on the whole test dataset I wouldn't use .batch":
Hmm, not so (at least not always). You would certainly need to use batch if your whole test dataset didn't fit into memory, which is a common occurrence. You would still want to test on the whole dataset, but run the numbers in manageable bites!
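Evaluating in batches does not change the result as long as the counts are accumulated across batches. A minimal sketch, assuming TensorFlow 2.x with eager execution; the labels and preds tensors are made-up stand-ins for a real test set and a model's predictions.

```python
import tensorflow as tf

# Made-up stand-ins (assumptions): true labels and a model's predicted classes.
labels = tf.constant([0, 1, 1, 0, 1, 0, 0, 1, 1, 1])
preds = tf.constant([0, 1, 0, 0, 1, 1, 0, 1, 1, 0])

# No shuffle for evaluation; batch only to bound memory use per step.
test_ds = tf.data.Dataset.from_tensor_slices((preds, labels)).batch(4)

correct = 0
total = 0
for p, y in test_ds:
    # Count matches in this batch and add them to the running totals.
    correct += int(tf.reduce_sum(tf.cast(tf.equal(p, y), tf.int32)))
    total += int(tf.shape(y)[0])

accuracy = correct / total
print(correct, total, accuracy)
```

Averaging the per-batch accuracies instead would give (0.75 + 0.75 + 0.5) / 3, roughly 0.67 rather than 0.7, because the last batch is smaller than the others. Accumulating counts avoids that bias.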
Source: https://stackoverflow.com/questions/56944856/tensorflow-dataset-questions-about-shuffle-batch-and-repeat