I'm trying to create a dataset from a CSV file with 784-bit-long rows. Here's my code:
import tensorflow as tf
f = open("test.csv", "r")
csvreader = c
From the docs, which you linked:
The generator argument must be a callable object that returns an object that supports the iter() protocol (e.g. a generator function).
This means you should be able to do something like this:
import tensorflow as tf
import csv
with open("test.csv", "r") as f:
    csvreader = csv.reader(f)
    gen = lambda: (row for row in csvreader)
    ds = tf.data.Dataset.from_generator(gen, [tf.uint8]*28**2)
    # f is closed when this block exits, so iterate ds inside it
In other words, the function you pass must produce a generator when called. This is easy to achieve by making it an anonymous function (a lambda).
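A plain-Python sketch (no TensorFlow needed) of what the lambda does, and one caveat: every call builds a new generator expression, but if that expression reads from one shared csv.reader, the second call finds the reader already used up. Here io.StringIO stands in for test.csv, and the three-column sample data is made up for illustration.

```python
import csv
import io

# Made-up stand-in for test.csv (3 columns instead of 784).
data = io.StringIO("1,2,3\n4,5,6\n")
csvreader = csv.reader(data)

# Same shape as the lambda above: each call returns a new generator...
gen = lambda: (row for row in csvreader)

first_pass = list(gen())
# ...but they all share one reader, which the first pass exhausted.
second_pass = list(gen())

print(first_pass)   # [['1', '2', '3'], ['4', '5', '6']]
print(second_pass)  # []
```

So the lambda version is fine for a single pass, but if the dataset is iterated more than once (several epochs, for instance), opening the file inside the generator function, as read_csv does, avoids the exhausted reader.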
Alternatively, try this, which is closer to how it is done in the docs:
import tensorflow as tf
import csv
def read_csv(file_name="test.csv"):
    with open(file_name) as f:
        reader = csv.reader(f)
        for row in reader:
            yield row
ds = tf.data.Dataset.from_generator(read_csv, [tf.uint8]*28**2)
(If you need a different file name than whatever default you set, you can use functools.partial(read_csv, file_name="whatever.csv").)
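A small sketch of the functools.partial suggestion: binding file_name leaves a callable that takes no arguments, so it can be passed anywhere read_csv itself could. The read_csv below is a copy of the generator function in this answer; the temporary file and its contents are made up so the sketch runs on its own.

```python
import csv
import functools
import os
import tempfile

def read_csv(file_name="test.csv"):
    # Copy of the answer's generator function.
    with open(file_name, newline="") as f:
        for row in csv.reader(f):
            yield row

# Made-up file with a non-default name.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as tmp:
    tmp.write("1,2\n3,4\n")
    path = tmp.name

# partial fixes file_name; the result is still a callable.
read_other = functools.partial(read_csv, file_name=path)

is_callable = callable(read_other)
first = list(read_other())   # [['1', '2'], ['3', '4']]
second = list(read_other())  # same rows again: each call reopens the file
os.remove(path)
print(is_callable, first, second)
```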
The difference is that the read_csv function returns the generator object when called, whereas what you constructed is already the generator object and equivalent to doing:
gen = read_csv()
ds = tf.data.Dataset.from_generator(gen, [tf.uint8]*28**2) # does not work: gen is a generator object, not a callable
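To see why the last snippet fails, here is a hypothetical helper, run_epochs, that mimics in plain Python (it is not TensorFlow code) how a from_generator dataset calls its generator argument again each time it is iterated. A callable gives fresh data on every pass; a generator object cannot be called at all, and even iterating it directly only works once. The in-memory rows are made up.

```python
def run_epochs(make_gen, epochs=2):
    # Hypothetical stand-in for iterating a from_generator dataset
    # several times: each pass calls make_gen() for a new iterator.
    if not callable(make_gen):
        raise TypeError("expected a callable that returns a generator")
    return [list(make_gen()) for _ in range(epochs)]

def read_rows():
    # Stand-in for read_csv, yielding made-up rows.
    for row in (["1", "2"], ["3", "4"]):
        yield row

epochs_ok = run_epochs(read_rows)  # two full passes

gen = read_rows()  # a generator object, like gen = read_csv()
try:
    run_epochs(gen)
    raised = False
except TypeError:
    raised = True

# Iterating the generator object directly works only once.
once = list(gen)
twice = list(gen)
print(epochs_ok, raised, once, twice)
```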