I would like to create a number of tf.data.Dataset objects using the from_generator() function. I would like to send an argument to the generator function (raw_data_gen).
By default, from_generator() calls the generator with no arguments, so you need to define a new function based on raw_data_gen that doesn't take any arguments. The lambda keyword is a convenient way to do this:
```python
training_dataset = tf.data.Dataset.from_generator(
    lambda: raw_data_gen(train_val_or_test=1),
    (tf.float32, tf.uint8), ([None, 1], [None]))
...
```
Now we pass from_generator a function that takes no arguments but simply calls raw_data_gen with the argument set to 1. The same scheme works for the validation and test sets, passing 2 and 3 respectively.
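Here is a minimal runnable sketch of the whole pattern. The question's generator isn't shown, so raw_data_gen below is a hypothetical stand-in that yields (features, labels) pairs matching the ([None, 1], [None]) shapes, filling them with the split id so each dataset is easy to tell apart. Since the goal is to build several datasets, it uses a loop; note the default argument split=split in the lambda. It binds the current loop value, whereas a plain lambda: raw_data_gen(split) would only see the final value (3), because the lambda isn't called until the dataset is iterated.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for raw_data_gen: yields two (features, labels) pairs
# with shapes (3, 1) and (3,), filled with the split id.
def raw_data_gen(train_val_or_test):
    for _ in range(2):
        n = 3  # rows per element (arbitrary for this sketch)
        features = np.full((n, 1), float(train_val_or_test), dtype=np.float32)
        labels = np.full((n,), train_val_or_test, dtype=np.uint8)
        yield features, labels

datasets = {}
for name, split in [("train", 1), ("val", 2), ("test", 3)]:
    # `split=split` captures the value of `split` for this iteration;
    # from_generator later calls the lambda with no arguments.
    datasets[name] = tf.data.Dataset.from_generator(
        lambda split=split: raw_data_gen(train_val_or_test=split),
        (tf.float32, tf.uint8),
        ([None, 1], [None]),
    )

for features, labels in datasets["val"].take(1):
    print(features.shape, labels.numpy())
```

Iterating datasets["val"] yields labels of 2 and datasets["test"] yields labels of 3, confirming each dataset got its own argument.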
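For TensorFlow 2.4, you can instead pass the argument through the args parameter of from_generator(). Note that args must be a tuple, so a single value needs a trailing comma:

```python
training_dataset = tf.data.Dataset.from_generator(
    raw_data_gen,
    args=(1,),  # one-element tuple; (1) is just the int 1 and raises a TypeError
    output_types=(tf.float32, tf.uint8),
    output_shapes=([None, 1], [None]))
```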
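A runnable sketch of the args route follows, again with a hypothetical raw_data_gen. Two details are worth knowing, per the TensorFlow documentation. First, values in args are converted to tensors and reach the generator as NumPy values rather than plain Python ints, so the sketch converts the value with int(). Second, since TF 2.4, output_signature with tf.TensorSpec is the recommended replacement for the deprecated output_types/output_shapes pair.

```python
import numpy as np
import tensorflow as tf

# Hypothetical generator. With `args`, TensorFlow calls raw_data_gen(*args)
# and passes each argument as a NumPy value, not a Python int.
def raw_data_gen(train_val_or_test):
    split = int(train_val_or_test)  # convert the NumPy value back to int
    for _ in range(2):
        n = 3  # rows per element (arbitrary for this sketch)
        yield (np.full((n, 1), float(split), dtype=np.float32),
               np.full((n,), split, dtype=np.uint8))

# `args` is a tuple, hence (1,). output_signature describes each element
# as a (features, labels) pair of TensorSpecs.
training_dataset = tf.data.Dataset.from_generator(
    raw_data_gen,
    args=(1,),
    output_signature=(
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.uint8),
    ),
)
```

The validation and test datasets follow the same shape, with args=(2,) and args=(3,).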