Question
I'm writing an input pipeline using tf.data.Dataset. I'd like to use Python code to load and transform my samples; the code returns a dictionary of tensors. Unfortunately, I don't see how to define that as the output type passed to tf.py_func.
I have a workaround where my function returns a list of tensors instead of a dictionary, but it makes my code less readable, as I have 4 keys in that dict.
The code looks roughly as follows:
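import numpy as np
import tensorflow as tf

file_list = ....

def load(file_name):
    # Load and transform one sample in plain Python.
    return {"image": np.zeros(..., dtype=np.float32),
            "label": 1.0}  # there are more labels in the original code

ds = tf.data.Dataset.from_tensor_slices(file_list)
# Dataset transformations return a new dataset, so reassign each time.
ds = ds.shuffle(...)
out_type = [{"image": tf.float32, "label": tf.float32}]  # ????
ds = ds.map(lambda x: tf.py_func(load, [x], out_type))
ds = ds.batch(...)
ds = ds.prefetch(1)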
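For reference, the list-based workaround mentioned above could be made less painful by keeping one fixed key order and rebuilding the dict right after the tf.py_func call. This is only a minimal sketch under assumptions: the names KEYS, dict_to_list, list_to_dict, and build_dataset are hypothetical, the (2, 2) image shape is a placeholder, and it targets TF 1.x, where tf.py_func is available (on TF 2.x, tf.numpy_function takes the same func, inp, Tout arguments).

```python
# Fixed key order shared by the Python loader and the TensorFlow graph.
KEYS = ("image", "label")


def dict_to_list(sample):
    """Flatten a sample dict into a list ordered by KEYS."""
    return [sample[key] for key in KEYS]


def list_to_dict(values):
    """Rebuild the sample dict from values ordered by KEYS."""
    return dict(zip(KEYS, values))


def build_dataset(file_list):
    # Imported here so the helpers above can be used without TensorFlow.
    import numpy as np
    import tensorflow as tf  # assumption: TF 1.x, where tf.py_func exists

    def load(file_name):
        # Hypothetical loader; the (2, 2) shape is only a placeholder.
        return {"image": np.zeros((2, 2), dtype=np.float32),
                "label": np.float32(1.0)}

    # One dtype per entry in KEYS, in the same order.
    out_types = [tf.float32, tf.float32]

    def load_as_dict(file_name):
        # py_func only sees a flat list; the dict is rebuilt in the graph.
        tensors = tf.py_func(lambda name: dict_to_list(load(name)),
                             [file_name], out_types)
        return list_to_dict(tensors)

    ds = tf.data.Dataset.from_tensor_slices(file_list)
    return ds.map(load_as_dict)
```

With this layout the loader keeps returning a dict, and only the two small helpers need to know the key order. Tensors coming out of tf.py_func have unknown static shapes, so a set_shape call on each tensor may be needed before batching.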
Source: https://stackoverflow.com/questions/48986874/how-to-use-py-func-with-a-function-that-returns-dict