How to use repeat() function when building data in Keras?

迷失自我 2021-01-04 22:58

I am training a binary classifier on a dataset of cats and dogs:
Total Dataset: 10000 images
Training Dataset: 8000 images
Validation/Test Dataset: 2000 images

1 answer
  • 2021-01-04 23:06

    Your problem stems from the fact that steps_per_epoch and validation_steps need to equal the total number of samples divided by batch_size, i.e. the number of batches in one pass over the data.

    Your code would have worked in Keras 1.x (prior to August 2017), where the corresponding arguments were counted in samples rather than batches.

    Change your training call to:

    history = model.fit_generator(training_set,
                                  steps_per_epoch=int(8000/batch_size),
                                  epochs=25,
                                  validation_data=test_set,
                                  validation_steps=int(2000/batch_size))
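
    To make the arithmetic concrete, here is a minimal sketch of the step counts. It assumes a batch size of 32, which the question does not state, and steps_for is a hypothetical helper rather than part of Keras. It also shows the rounded-up variant, in case the final partial batch should be counted as a step.

```python
import math

def steps_for(num_samples, batch_size, keep_partial_batch=False):
    """Number of batches needed for one pass over num_samples items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if keep_partial_batch:
        # Round up so the last, smaller batch counts as a step.
        return math.ceil(num_samples / batch_size)
    # Round down, matching int(num_samples / batch_size) in the answer.
    return num_samples // batch_size

batch_size = 32  # assumed value, not given in the question

steps_per_epoch = steps_for(8000, batch_size)    # 8000 = 250 * 32
validation_steps = steps_for(2000, batch_size)   # 2000 = 62 * 32 + 16
validation_steps_all = steps_for(2000, batch_size, keep_partial_batch=True)
```

    With these numbers, rounding down leaves 16 validation images unused per pass, while rounding up adds one shorter step.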
    

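    As of TensorFlow 2.1, fit_generator() is deprecated; model.fit() accepts generators directly.

    TensorFlow >= 2.1 code (note that .repeat() is a tf.data.Dataset method, so this assumes training_set and test_set are Dataset objects):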

    history = model.fit(training_set.repeat(),
                        steps_per_epoch=int(8000/batch_size),
                        epochs=25,
                        validation_data=test_set.repeat(),
                        validation_steps=int(2000/batch_size))
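
    The reason for .repeat() can be illustrated without TensorFlow. The sketch below is a plain-Python analogy, not Keras code: a one-pass iterator stands in for the dataset, itertools.cycle stands in for .repeat(), and run_epochs is a hypothetical stand-in for a training loop that draws steps_per_epoch batches in each epoch.

```python
from itertools import cycle, islice

def run_epochs(batches, steps_per_epoch, epochs):
    """Draw steps_per_epoch batches per epoch; return how many each epoch got."""
    it = iter(batches)
    return [len(list(islice(it, steps_per_epoch))) for _ in range(epochs)]

data = [f"batch_{i}" for i in range(4)]  # one full pass = 4 batches

# Without repeating, the iterator is exhausted after the first epoch.
one_pass = run_epochs(iter(data), steps_per_epoch=4, epochs=3)

# With repeating, every epoch receives its full share of batches.
repeated = run_epochs(cycle(data), steps_per_epoch=4, epochs=3)

print(one_pass)   # [4, 0, 0]
print(repeated)   # [4, 4, 4]
```

    Because a repeated source never ends, steps_per_epoch and validation_steps are what tell the loop where one epoch stops.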
    

    Note that, for positive integers, int(8000/batch_size) is equivalent to 8000 // batch_size (integer division).
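
    A quick check of that equivalence, as a sketch with arbitrarily chosen batch sizes:

```python
num_samples = 8000
batch_sizes = [16, 32, 48, 64, 100]  # arbitrary values for illustration

# Float division followed by truncation, as written in the answer.
via_int = [int(num_samples / b) for b in batch_sizes]

# Floor division, which stays in integer arithmetic throughout.
via_floordiv = [num_samples // b for b in batch_sizes]

print(via_int)       # [500, 250, 166, 125, 80]
print(via_floordiv)  # [500, 250, 166, 125, 80]

# The two forms differ only for negative operands:
truncated = int(-7 / 2)   # -3, truncates toward zero
floored = -7 // 2         # -4, rounds toward negative infinity
```

    Sample counts and batch sizes are never negative, so the difference does not matter here; // simply avoids the intermediate float.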
