I was confused by this problem for several days...
My question is why the training time differs so massively between setting batch_size to "1"
You should also take into account the following function parameters when working with fit_generator: max_queue_size, use_multiprocessing and workers.
max_queue_size - may cause more data to be loaded than you actually expect. Depending on your generator code, this can trigger unexpected or unnecessary work that slows down execution.
use_multiprocessing together with workers - may spin up additional processes, which adds work for serialization and interprocess communication. First your data is serialized using pickle, then it is sent to the target processes, then the processing happens inside those processes, and then the whole communication procedure repeats in reverse: the results are pickled and sent back to the main process. In most cases this should be fast, but if you're processing dozens of gigabytes of data or your generator is implemented in a sub-optimal fashion, you might get the slowdown you describe.
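To get a rough feel for the serialization cost described above, here is a minimal sketch that uses only the Python standard library. It does not call Keras and does not start any processes; the batch size, the sample length and the helper name roundtrip_seconds are all assumptions made up for illustration, and the measured time will vary by machine.

```python
import pickle
import time


def roundtrip_seconds(batch):
    """Pickle and unpickle a batch once, returning (copy, elapsed seconds).

    This mimics only the serialization step paid when a batch crosses a
    process boundary; no worker processes are actually started here.
    """
    start = time.perf_counter()
    payload = pickle.dumps(batch, protocol=pickle.HIGHEST_PROTOCOL)
    restored = pickle.loads(payload)
    return restored, time.perf_counter() - start


# Hypothetical batch: 32 samples of 10,000 floats each (sizes are assumptions).
batch = [[float(i)] * 10_000 for i in range(32)]
restored, elapsed = roundtrip_seconds(batch)
print(f"round trip took {elapsed:.4f} s for {len(batch)} samples")
```

With multiprocessing enabled, a cost of this kind is paid for every batch, so it only matters when batches are large or the model step itself is very cheap.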
When you use fit_generator, the number of samples processed in each epoch is batch_size * steps_per_epoch. From the Keras documentation for fit_generator (https://keras.io/models/sequential/):
steps_per_epoch: Total number of steps (batches of samples) to yield from generator before declaring one epoch finished and starting the next epoch. It should typically be equal to the number of unique samples of your dataset divided by the batch size.
This is different from the behaviour of 'fit', where increasing batch_size typically speeds up things.
In conclusion, when you increase batch_size with fit_generator, you should decrease steps_per_epoch by the same factor if you want training time to stay the same or go down.
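The arithmetic behind this can be sketched in a few lines of plain Python. The dataset size of 10,000 and the helper name steps_for_one_pass are hypothetical and chosen only for illustration.

```python
import math


def steps_for_one_pass(num_samples, batch_size):
    """Batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


num_samples = 10_000  # hypothetical dataset size

# Scaling steps_per_epoch down as batch_size goes up keeps one epoch = one pass.
for batch_size in (1, 32, 256):
    steps = steps_for_one_pass(num_samples, batch_size)
    print(f"batch_size={batch_size:>3}  steps_per_epoch={steps:>5}")

# Leaving steps_per_epoch fixed while raising batch_size multiplies the
# number of samples processed per epoch, and with it the training time.
fixed_steps = 10_000
samples_seen = {bs: bs * fixed_steps for bs in (1, 32)}
print(samples_seen)
```

So if steps_per_epoch was tuned for batch_size 1 and left unchanged, moving to batch_size 32 makes each epoch process 32 times as many samples.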
In short:
fit() works faster than fit_generator() since it can access data directly in memory.
fit() takes NumPy array data that is already in memory, while fit_generator() takes data from a generator or a sequence such as keras.utils.Sequence, which works slower.
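For reference, a keras.utils.Sequence is an object that reports its number of batches through __len__ and returns one batch per index through __getitem__. The sketch below mirrors that interface in plain Python without importing Keras, so it is only a stand-in: the class name ListBatches and the toy data are hypothetical, and real code would subclass keras.utils.Sequence instead.

```python
import math


class ListBatches:
    """Plain-Python stand-in mirroring the Sequence interface (hypothetical name)."""

    def __init__(self, x, y, batch_size):
        self.x = x
        self.y = y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches in one epoch.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Slice out batch number idx; the last batch may be smaller.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        return self.x[lo:hi], self.y[lo:hi]


# Toy data, made up for illustration.
x = list(range(10))
y = [v % 2 for v in x]
batches = ListBatches(x, y, batch_size=4)

print(len(batches))
print(batches[2])
```

Every batch goes through a Python-level __getitem__ call, which is part of the per-batch overhead that fit() on in-memory arrays avoids.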