Input multiple files into Tensorflow dataset

Submitted by 吃可爱长大的小学妹 on 2019-12-05 02:56:40

Question


I have the following input_fn.

def input_fn(filenames, batch_size):
    # Create a dataset containing the text lines.
    dataset = tf.data.TextLineDataset(filenames).skip(1)

    # Parse each line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

It works great if filenames=['file1.csv'] or filenames=['file2.csv'], but it gives me an error if filenames=['file1.csv', 'file2.csv']. The TensorFlow documentation says filenames is a tf.string tensor containing one or more filenames. How should I import multiple files?

The error is below. It seems to be ignoring the .skip(1) in the input_fn above:

InvalidArgumentError: Field 0 in record 0 is not a valid int32: row_id
 [[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_INT32, DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13, DecodeCSV/record_defaults_14, DecodeCSV/record_defaults_15, DecodeCSV/record_defaults_16, DecodeCSV/record_defaults_17, DecodeCSV/record_defaults_18)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], ..., [?], [?], [?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT32, DT_STRING, DT_STRING, ..., DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]

Answer 1:


You have the right idea using tf.data.TextLineDataset. However, your current implementation first concatenates the lines of every file in the filenames tensor into a single dataset and only then applies skip(1), so it drops just the first line of the first file. The first line of the second file is not skipped, so its header reaches the CSV parser and produces the "not a valid int32: row_id" error.
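
The difference between the two skipping strategies can be illustrated without TensorFlow. The following is a minimal sketch using plain Python iterators as a stand-in for the dataset pipeline; the in-memory files dictionary and the function names skip_once and skip_per_file are hypothetical and exist only for illustration.

```python
from itertools import chain, islice

# Hypothetical stand-ins for two CSV files, each starting with a header line.
files = {
    "file1.csv": ["row_id,value", "1,0.5", "2,1.5"],
    "file2.csv": ["row_id,value", "3,2.5"],
}


def skip_once(filenames):
    # Analogue of TextLineDataset(filenames).skip(1):
    # concatenate all lines first, then drop a single line.
    all_lines = chain.from_iterable(files[name] for name in filenames)
    return list(islice(all_lines, 1, None))


def skip_per_file(filenames):
    # Analogue of skipping inside each file before the lines are combined.
    per_file = (islice(files[name], 1, None) for name in filenames)
    return list(chain.from_iterable(per_file))


names = ["file1.csv", "file2.csv"]
print(skip_once(names))      # the header of file2.csv is still present
print(skip_per_file(names))  # only data rows remain
```

With a single file the two functions return the same rows, which matches the observation that each file works when passed on its own.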

Based on the example in the Datasets guide, you should first create a regular Dataset from the filenames, then use flat_map to read each file with its own TextLineDataset, skipping the first row of each file:

d = tf.data.Dataset.from_tensor_slices(filenames)
# Get a dataset from each file, skipping the first line of each file.
d = d.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1))
d = d.map(_parse_line)  # And whatever else you need to do

Here, flat_map maps every element of the original dataset (a filename) to a new dataset that holds that file's lines minus the first one, and then flattens those per-file datasets into a single dataset of lines.
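
Putting this together with the input_fn from the question might look like the sketch below. It assumes TensorFlow 2.x (where the CSV op is exposed as tf.io.decode_csv) and a hypothetical two-column schema of an int32 row_id and a float value; the real _parse_line and its 19 record defaults are not shown in the question, so the version here is only a placeholder. The helper name lines_without_headers is likewise made up for illustration.

```python
import tensorflow as tf


def _parse_line(line):
    # Placeholder parser: assumes two columns, an int32 row_id and a
    # float32 value. Replace with the real record_defaults for your files.
    row_id, value = tf.io.decode_csv(line, record_defaults=[[0], [0.0]])
    return {"row_id": row_id}, value


def lines_without_headers(filenames):
    # One dataset element per filename.
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Read each file separately and drop its header before the per-file
    # datasets are flattened into one stream of lines.
    return dataset.flat_map(
        lambda filename: tf.data.TextLineDataset(filename).skip(1))


def input_fn(filenames, batch_size):
    dataset = lines_without_headers(filenames).map(_parse_line)
    # Shuffle, repeat, and batch the examples, as in the question.
    return dataset.shuffle(10000).repeat().batch(batch_size)
```

Calling input_fn(['file1.csv', 'file2.csv'], batch_size) should then yield batches built only from data rows, because every header is removed before parsing.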



Source: https://stackoverflow.com/questions/49180980/input-multiple-files-into-tensorflow-dataset
