Question
I am trying to use fit_generator with a custom generator to read in data that's too big for memory. There are 1.25 million rows I want to train on, so I have the generator yield 50,000 rows at a time. fit_generator has steps_per_epoch=25, which I thought would bring in those 1.25 million rows per epoch. I added a print statement so that I could see how far the offset had advanced, and I found that it exceeded the maximum a few steps into epoch 2. There are 1.75 million records in that file in total, and once it passes 10 steps into epoch 2, the create_feature_matrix call raises an index error (because it receives no rows).
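The offset arithmetic, using the row counts from the question, reproduces the failure point exactly (a minimal sketch of the bookkeeping, not the asker's actual script):

```python
nrows = 50_000           # rows yielded per generator call
steps_per_epoch = 25     # 25 * 50,000 = 1,250,000 rows per epoch
total_rows = 1_750_000   # rows in the file

# The generator never rewinds, so skiprows keeps growing across epochs.
skiprows = 0
for epoch in (1, 2):
    for step in range(1, steps_per_epoch + 1):
        rows_left = max(total_rows - skiprows, 0)
        if rows_left == 0:
            # Epoch 1 consumes rows 0..1,249,999; epoch 2's steps 1-10
            # consume the remaining 500,000, so step 11 reads nothing.
            print(f"epoch {epoch}, step {step}: no rows left")
            break
        skiprows += nrows
```

This is why the error appears "a few steps into epoch 2": the file holds only 0.5 million rows beyond what epoch 1 already consumed, which is exactly 10 more steps of 50,000 rows each.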
import pandas as pd
import gc

def get_next_data_batch():
    nrows = 50000
    skiprows = 0
    while True:
        d = pd.read_csv(file_loc, skiprows=range(1, skiprows), nrows=nrows, index_col=0)
        print(skiprows)
        x, y = create_feature_matrix(d)
        yield x, y
        skiprows = skiprows + nrows
        gc.collect()
get_data = get_next_data_batch()

... set up a Keras NN ...

model.fit_generator(get_data, epochs=100, steps_per_epoch=25, verbose=1,
                    workers=4, callbacks=callbacks_list)
Am I using fit_generator wrong or is there some change that needs to be made to my custom generator to get this to work?
Answer 1:
No - fit_generator doesn't reset the generator; it simply keeps calling it. To achieve the behavior you want, you could try the following:
import pandas as pd
import gc

def get_next_data_batch(nb_of_calls_before_reset=25):
    nrows = 50000
    skiprows = 0
    nb_calls = 0
    while True:
        d = pd.read_csv(file_loc, skiprows=range(1, skiprows), nrows=nrows, index_col=0)
        print(skiprows)
        x, y = create_feature_matrix(d)
        yield x, y
        nb_calls += 1
        if nb_calls == nb_of_calls_before_reset:
            # Rewind to the start of the file and restart the call count;
            # without resetting nb_calls, the rewind would fire only once.
            skiprows = 0
            nb_calls = 0
        else:
            skiprows = skiprows + nrows
        gc.collect()
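The rewind logic can be verified on a tiny in-memory CSV. The snippet below is a self-contained sketch with made-up column names and batch sizes (6 data rows, batches of 2, so 3 calls per "epoch"); it shows the generator restarting from the top of the file after each full pass:

```python
import io
import pandas as pd

# Synthetic stand-in for the real file: 6 data rows with an index column.
csv_text = "idx,a,b\n" + "".join(f"{i},{i},{i * 10}\n" for i in range(6))

def batch_generator(nrows=2, calls_before_reset=3):
    skiprows = 0
    nb_calls = 0
    while True:
        # Skip the data lines already consumed (line 0 is the header).
        d = pd.read_csv(io.StringIO(csv_text),
                        skiprows=range(1, skiprows + 1),
                        nrows=nrows, index_col=0)
        yield d
        nb_calls += 1
        if nb_calls == calls_before_reset:
            skiprows = 0   # rewind for the next epoch
            nb_calls = 0
        else:
            skiprows += nrows

gen = batch_generator()
first_rows = [int(next(gen).index[0]) for _ in range(7)]
print(first_rows)  # [0, 2, 4, 0, 2, 4, 0]
```

The first index of each yielded batch cycles 0, 2, 4 and then returns to 0, confirming the per-epoch reset; with the question's numbers, the same pattern keeps skiprows within the 1.25 million training rows on every epoch.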
Source: https://stackoverflow.com/questions/48729107/is-fit-generator-in-keras-supposed-to-reset-the-generator-after-each-epoch