Question
I have the following input_fn.
def input_fn(filenames, batch_size):
    # Create a dataset containing the text lines.
    dataset = tf.data.TextLineDataset(filenames).skip(1)
    # Parse each line.
    dataset = dataset.map(_parse_line)
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)
    # Return the dataset.
    return dataset
It works great if filenames=['file1.csv'] or filenames=['file2.csv'], but it gives me an error if filenames=['file1.csv', 'file2.csv']. The TensorFlow documentation says filenames is a tf.string tensor containing one or more filenames. How should I import multiple files?
The error is below. It seems to be ignoring the .skip(1) in the input_fn above:
InvalidArgumentError: Field 0 in record 0 is not a valid int32: row_id
[[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_INT32, DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13, DecodeCSV/record_defaults_14, DecodeCSV/record_defaults_15, DecodeCSV/record_defaults_16, DecodeCSV/record_defaults_17, DecodeCSV/record_defaults_18)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], ..., [?], [?], [?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT32, DT_STRING, DT_STRING, ..., DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]
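To see where this error comes from, here is a minimal sketch that decodes one data row and one header row with tf.io.decode_csv. It assumes TensorFlow 2.x with eager execution and a hypothetical two-column schema (an int32 row_id and a float32 value). The data row parses, while the header row should raise an InvalidArgumentError much like the one above, because the string "row_id" is not a valid int32.

```python
import tensorflow as tf

# Hypothetical schema: column 0 is an int32 row_id, column 1 a float32 value.
RECORD_DEFAULTS = [[0], [0.0]]


def parse(line):
    # Decode one CSV line into a list of scalar tensors, one per column.
    return tf.io.decode_csv(line, record_defaults=RECORD_DEFAULTS)


# A data row converts cleanly to the declared column types.
row_id, value = parse(tf.constant("7,1.5"))
print(int(row_id.numpy()), float(value.numpy()))

# A header row cannot be converted to int32, so decoding raises.
header_failed = False
try:
    parse(tf.constant("row_id,value"))
except tf.errors.InvalidArgumentError as err:
    header_failed = True
    print(err.message)
```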
Answer 1:
You have the right idea using tf.data.TextLineDataset. However, your current implementation yields every line of every file in the filenames tensor except the first line of the first file. TextLineDataset reads the files one after another as a single stream of lines, so .skip(1) only drops the very first line of the very first file. The first line of the second file (its header) is not skipped, and trying to parse that header as an int32 is what causes the error.
Based on the example in the Datasets guide, you should first create a regular Dataset from the filenames, then call flat_map on it so that each filename is read with TextLineDataset while skipping that file's first row:
import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices(filenames)
# get dataset from each file, skipping first line of each file
d = d.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1))
d = d.map(_parse_line) # And whatever else you need to do
Here, flat_map maps each element of the original dataset (a filename) to a new dataset containing that file's lines without its first line, and concatenates these per-file datasets into a single dataset.
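Putting the pieces together, here is a sketch of the question's input_fn rewritten with flat_map. It assumes TensorFlow 2.x; the _parse_line body, RECORD_DEFAULTS and COLUMN_NAMES are hypothetical stand-ins for the question's parser and schema, which are not shown. The two temporary files at the end exist only as a usage example.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical schema: replace with one default and one name per CSV column.
RECORD_DEFAULTS = [[0], [0.0]]
COLUMN_NAMES = ["row_id", "value"]


def _parse_line(line):
    # Hypothetical stand-in for the question's _parse_line.
    fields = tf.io.decode_csv(line, record_defaults=RECORD_DEFAULTS)
    return dict(zip(COLUMN_NAMES, fields))


def input_fn(filenames, batch_size):
    # One element per filename.
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Read each file on its own, drop its header line, then concatenate.
    dataset = dataset.flat_map(
        lambda filename: tf.data.TextLineDataset(filename).skip(1))
    dataset = dataset.map(_parse_line)
    # Shuffle, repeat, and batch the examples as in the original input_fn.
    return dataset.shuffle(10000).repeat().batch(batch_size)


# Usage: two hypothetical files that both start with a header line.
tmp_dir = tempfile.mkdtemp()
paths = []
for name, rows in [("file1.csv", ["1,0.5", "2,1.5"]), ("file2.csv", ["3,2.5"])]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w") as f:
        f.write("\n".join(["row_id,value"] + rows) + "\n")
    paths.append(path)

first_batch = next(iter(input_fn(paths, batch_size=3)))
print(sorted(first_batch["row_id"].numpy().tolist()))
```

Because from_tensor_slices produces one element per file, skip(1) runs once per file, so every header is dropped and the parser never sees the string "row_id".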
Source: https://stackoverflow.com/questions/49180980/input-multiple-files-into-tensorflow-dataset