Question
When I train the model using the .fit() method, there is a shuffle argument, which defaults to True.
Let's say that my dataset has 100 samples and the batch size is 10. When I set shuffle = True, Keras first shuffles the samples randomly (so the 100 samples are now in a different order) and then builds the batches from the new order: batch 1: 1-10, batch 2: 11-20, etc.
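For comparison, the sample-level shuffling described above can be sketched in NumPy as follows. This is a minimal illustration of the idea under the question's assumptions (100 samples, batch size 10), not Keras's own code; the helper name `sample_shuffle_batches` is hypothetical, and indices are 0-based rather than the 1-based labels used in the text.

```python
import numpy as np


def sample_shuffle_batches(n_samples, batch_size, rng):
    """Hypothetical helper: permute every sample index, then cut the
    permuted order into consecutive batches (the shuffle=True idea)."""
    index_array = rng.permutation(n_samples)  # new random order of all samples
    return [index_array[start:start + batch_size]
            for start in range(0, n_samples, batch_size)]


rng = np.random.default_rng(0)
batches = sample_shuffle_batches(100, 10, rng)
print(len(batches))  # 10
print(batches[0])    # ten indices that can come from anywhere in 0..99
```

Because the permutation is applied to individual samples, a single batch usually mixes samples from all over the dataset.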
If I set shuffle = 'batch', how is it supposed to work in the background? Intuitively, using the previous example of a 100-sample dataset with batch size 10, my guess is that Keras first allocates the samples to batches (batch 1: samples 1-10 in the dataset's original order, batch 2: samples 11-20, also in the original order, batch 3 ..., and so on) and then shuffles the order of the batches. The model would then be trained on the batches in random order, for example: 3 (contains samples 21-30), 4 (contains samples 31-40), 7 (contains samples 61-70), 1 (contains samples 1-10), ... (I made up the order of the batches).
Is my thinking right or am I missing something?
Thanks!
Answer 1:
Looking at the implementation at this link (line 349 of training.py), the answer seems to be yes.
You can check this with the following code:
```python
import numpy as np


def batch_shuffle(index_array, batch_size):
    """Shuffles an array in a batch-wise fashion.

    Useful for shuffling HDF5 arrays
    (where one cannot access arbitrary indices).

    # Arguments
        index_array: array of indices to be shuffled.
        batch_size: integer.

    # Returns
        The `index_array` array, shuffled in a batch-wise fashion.
    """
    batch_count = len(index_array) // batch_size
    # to reshape we need to be cleanly divisible by batch size
    # we stash extra items and reappend them after shuffling
    last_batch = index_array[batch_count * batch_size:]
    index_array = index_array[:batch_count * batch_size]
    # one row per batch; np.random.shuffle permutes rows only,
    # so the samples inside each batch keep their original order
    index_array = index_array.reshape((batch_count, batch_size))
    np.random.shuffle(index_array)
    index_array = index_array.flatten()
    return np.append(index_array, last_batch)


x = np.array(range(100))
# slicing and reshaping give views, so pass a copy to leave x untouched
x_s = batch_shuffle(x.copy(), 10)
print(x_s)
```
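To confirm that the output matches the behaviour guessed in the question, the check below verifies that every chunk of 10 in the result is one of the original batches (consecutive indices in their original order) and that each original batch appears exactly once. It also tries a length that is not a multiple of the batch size (105, an arbitrary choice) to show that the leftover samples stay at the end. `batch_shuffle` is repeated in condensed form so the snippet runs on its own; `is_batch_permutation` is a hypothetical helper written only for this check.

```python
import numpy as np


def batch_shuffle(index_array, batch_size):
    # Condensed version of the function above (same logic).
    # Note: this shuffles the input's memory in place via a view.
    batch_count = len(index_array) // batch_size
    last_batch = index_array[batch_count * batch_size:]
    full = index_array[:batch_count * batch_size].reshape((batch_count, batch_size))
    np.random.shuffle(full)  # shuffles rows, i.e. whole batches
    return np.append(full.flatten(), last_batch)


def is_batch_permutation(shuffled, batch_size):
    """Hypothetical check: True if every full chunk of `shuffled` is a run
    of consecutive indices starting at a multiple of batch_size, and each
    such original batch occurs exactly once."""
    batch_count = len(shuffled) // batch_size
    chunks = shuffled[:batch_count * batch_size].reshape((batch_count, batch_size))
    starts = []
    for chunk in chunks:
        start = int(chunk[0])
        if start % batch_size != 0:
            return False
        if not np.array_equal(chunk, np.arange(start, start + batch_size)):
            return False
        starts.append(start)
    return sorted(starts) == [b * batch_size for b in range(batch_count)]


np.random.seed(0)
x_s = batch_shuffle(np.arange(100), 10)
print(is_batch_permutation(x_s, 10))  # True

y_s = batch_shuffle(np.arange(105), 10)
print(is_batch_permutation(y_s, 10))  # True
print(y_s[-5:])                       # [100 101 102 103 104]
```

The incomplete final batch is never moved: it is stashed before the reshape and appended after the shuffle, so it is always visited last.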
Source: https://stackoverflow.com/questions/45567692/how-does-shuffle-batch-argument-of-the-fit-layer-work-in-the-background