Question
I want to have a different number of samples in every epoch. For example, at epoch 1 I want 100 samples (all samples), and at the second epoch I want only 50 samples. Right now I'm doing this with the tf.data.Dataset filter method. I'm using the models/official/resnet TF code with multiple GPUs.
My problem: after a random number of epochs, the program hangs and even CTRL+C cannot kill it. My question is: would a different number of samples per epoch cause any problem?
My filter's predicate function returns true or false based on one condition applied to each sample. I'm wondering whether some hooks or other components depend on the number of training samples seen at the beginning, and when the number of samples decreases they wait for more samples that are not available, which is probably why my program hangs (no GPU utilization).
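Roughly, the predicate looks like the following sketch (the "score" feature name and the 0.5 threshold are placeholders, not my real condition):

import tensorflow as tf

def my_predicate(features, label):
    # One condition evaluated per sample; dataset.filter() drops samples for which
    # this returns False, so the number of surviving samples can differ per epoch.
    # "score" and 0.5 are made-up placeholders for my actual condition.
    return tf.greater(features["score"], 0.5)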
I create my dataset as follows:
dataset = dataset.prefetch(buffer_size=batch_size)
dataset = dataset.shuffle(buffer_size=shuffle_buffer)   # buffer size omitted in my snippet
dataset = dataset.repeat(1)
dataset = dataset.map(parse_fn)                          # parse_fn parses each raw sample
dataset = dataset.filter(my_predicate)                   # one condition per sample, as above
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
For training I'm using tf.estimator.
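For context, a stripped-down, self-contained version of the setup looks like the sketch below (toy data, toy model_fn, placeholder filter condition; none of this is the actual models/official/resnet code):

import tensorflow as tf

def input_fn():
    # Toy stand-in for my real pipeline: 100 scalar samples with integer labels.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random_uniform([100, 1]), tf.zeros([100], dtype=tf.int32)))
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.repeat(1)
    dataset = dataset.filter(lambda x, y: tf.greater(x[0], 0.3))  # placeholder condition
    dataset = dataset.batch(10)
    return dataset

def model_fn(features, labels, mode):
    # Minimal linear model, only here to make the example runnable.
    logits = tf.layers.dense(features, 2)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)  # one train() call = one pass over the filtered samples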
Source: https://stackoverflow.com/questions/54221770/what-happens-if-number-of-samples-changes-every-epoch-using-tf-data-dataset-filt