I'm trying to create a dataset from a CSV file whose rows each hold 784 values. Here's my code:
import tensorflow as tf
f = open("test.csv", "r")
csvreader = c
The `generator` argument (perhaps confusingly) should not actually be a generator, but a callable returning an iterable (for example, a generator function). Probably the easiest option here is to use a `lambda`. There are also a couple of other errors:

1) `tf.data.Dataset.from_generator` is meant to be called as a class factory method, not on an instance.
2) The function (like a few others in TensorFlow) is weirdly picky about parameters: it wants you to give the sequence of dtypes, and each data row as a `tuple` (instead of the `list`s returned by the CSV reader). You can use, for example, `map` for that:
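import csv
import tensorflow as tf

with open("test.csv", "r") as f:
    csvreader = csv.reader(f)
    # from_generator is called on the class and takes a callable (the lambda);
    # map(tuple, ...) turns each list row from csv.reader into a tuple
    ds = tf.data.Dataset.from_generator(lambda: map(tuple, csvreader),
                                        (tf.uint8,) * (28 ** 2))  # one dtype per column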
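One caveat with the snippet above: the lambda always maps over the same `csvreader` iterator, which is tied to a file that is closed as soon as the `with` block ends. TensorFlow only calls the callable when the dataset is iterated, so iterating `ds` outside the block fails on the closed file, and even inside it a second pass finds the reader already exhausted. A minimal sketch of an alternative, assuming the same one-row-per-sample `test.csv` layout, is a small factory (the name `make_row_source` is hypothetical, not part of TensorFlow) that returns a callable which reopens the file on every call and yields each row as a tuple:

```python
import csv
from typing import Callable, Iterator, Tuple


def make_row_source(path: str) -> Callable[[], Iterator[Tuple[int, ...]]]:
    """Return a callable that yields one tuple of ints per CSV row.

    Hypothetical helper: every call reopens the file, so each pass over
    the data gets a fresh iterator that starts again at the first row.
    """
    def rows() -> Iterator[Tuple[int, ...]]:
        # newline="" is what the csv module recommends for files it reads
        with open(path, "r", newline="") as f:
            for row in csv.reader(f):
                # csv.reader yields lists of strings; convert each row to a
                # tuple of ints (assumption: every field is an integer)
                yield tuple(int(value) for value in row)

    return rows
```

With TensorFlow installed, the returned callable can be passed directly as the `generator` argument, for example `tf.data.Dataset.from_generator(make_row_source("test.csv"), (tf.uint8,) * (28 ** 2))`, with no `lambda` needed. Converting the fields to `int` in Python is a choice of this sketch; the answer's version passes the raw strings through.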