Question
I would like to create a number of tf.data.Dataset objects using the from_generator() function, and I would like to pass an argument to the generator function (raw_data_gen). The idea is that the generator yields different data depending on the argument it receives, so that raw_data_gen can provide training, validation, or test data.
training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))
validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))
test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))
The error message I get when I call from_generator() in this way is:
TypeError: from_generator() got an unexpected keyword argument 'args'
Here is the raw_data_gen function, although I'm not sure you will need it, as my hunch is that the problem lies in the call to from_generator():
def raw_data_gen(train_val_or_test):
    if train_val_or_test == 1:
        # For every filename collected in the dict
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try:  # assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:, 0]  # raw_data is a np.array, just take the first channel
            except IndexError:
                pass  # this must be mono audio
            yield raw_data, lab
    elif train_val_or_test == 2:
        # For every filename collected in the dict
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try:  # assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:, 0]  # raw_data is a np.array, just take the first channel
            except IndexError:
                pass  # this must be mono audio
            yield raw_data, lab
    elif train_val_or_test == 3:
        # For every filename collected in the dict
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try:  # assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:, 0]  # raw_data is a np.array, just take the first channel
            except IndexError:
                pass  # this must be mono audio
            yield raw_data, lab
    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()
Answer 1:
You need to define a new function, based on raw_data_gen, that doesn't take any arguments. You can use the lambda keyword to do this.
training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...
Now we are passing from_generator a function that takes no arguments but simply acts as raw_data_gen with the argument set to 1. You can use the same scheme for the validation and test sets, passing 2 and 3 respectively.
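As a minimal sketch of the idea above: the wrapper only needs to be a zero-argument callable that returns a fresh generator each time it is called. The toy numbers_gen below is a hypothetical stand-in for raw_data_gen, so the example runs without TensorFlow or audio files. functools.partial is shown as an alternative to lambda; it may be the safer choice when building the wrappers in a loop, because a lambda looks up the loop variable only when it is finally called.

```python
from functools import partial

def numbers_gen(split_id):
    """Hypothetical stand-in for raw_data_gen: yields (value, split_id) pairs."""
    for i in range(3):
        yield i * split_id, split_id

# Zero-argument callable built with lambda, as in the answer above.
train_factory = lambda: numbers_gen(split_id=1)

# functools.partial binds the argument immediately, so it is safe in a loop.
factories = {
    name: partial(numbers_gen, split_id)
    for name, split_id in [("train", 1), ("val", 2), ("test", 3)]
}

train_items = list(train_factory())
val_items = list(factories["val"]())
# Each call returns a fresh generator, so the data can be iterated again.
val_items_again = list(factories["val"]())
```

Either train_factory or factories["train"] could then be passed as the first argument to from_generator in place of the bare generator function.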
Answer 2:
For TensorFlow 2.4:
training_dataset = tf.data.Dataset.from_generator(
    raw_data_gen,
    args=(1,),  # trailing comma makes this a one-element tuple; (1) is just the int 1
    output_types=(tf.float32, tf.uint8),
    output_shapes=([None, 1], [None]))
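One detail worth checking in calls like the one above: in Python, parentheses alone do not create a tuple, so (1) is the integer 1 while (1,) is a one-element tuple. The sketch below uses a hypothetical call_with_args helper (not part of TensorFlow) purely to illustrate how star-unpacking behaves in each case; how a given TensorFlow version reacts to a bare integer is not shown here.

```python
def call_with_args(fn, args):
    """Hypothetical helper: unpack args into fn, roughly as a framework might."""
    return fn(*args)

def describe(split_id):
    return f"split {split_id}"

not_a_tuple = (1)   # plain int, the parentheses are only grouping
one_tuple = (1,)    # one-element tuple

ok = call_with_args(describe, one_tuple)

try:
    call_with_args(describe, not_a_tuple)
    failed = False
except TypeError:
    # An int cannot be star-unpacked.
    failed = True
```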
Source: https://stackoverflow.com/questions/52443273/how-do-you-send-arguments-to-a-generator-function-using-tf-data-dataset-from-gen