How to make a generator callable?

渐次进展 2021-01-04 10:12

I'm trying to create a dataset from a CSV file whose rows each hold 784 values. Here's my code:

import tensorflow as tf

f = open("test.csv", "r")
csvreader = c         


        
3 Answers
  •  有刺的猬
    2021-01-04 10:21

    From the docs, which you linked:

    The generator argument must be a callable object that returns an object that supports the iter() protocol (e.g. a generator function).

    This means you should be able to do something like this:

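    import tensorflow as tf
    import csv
    
    with open("test.csv", "r") as f:
        csvreader = csv.reader(f)
        # A lambda is callable: each call returns a new generator expression
        gen = lambda: (row for row in csvreader)
        # from_generator is a static method that returns a new Dataset;
        # tf.data.Dataset itself is abstract and cannot be instantiated directly
        ds = tf.data.Dataset.from_generator(gen, [tf.uint8]*28**2)
        # Iterate ds inside this block: the file is closed once the block exits,
        # and csvreader is used up after one full pass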
    

    In other words, the function you pass must produce a generator when called. This is easy to achieve by wrapping the generator expression in an anonymous function (a lambda).
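    The distinction the docs draw can be checked in plain Python, without TensorFlow. In this sketch the CSV content is a hypothetical in-memory string read through `io.StringIO`; `gen_obj` is the kind of generator object the question built, and `gen_fn` is the lambda from the snippet above.

```python
import csv
import io

# Hypothetical in-memory stand-in for test.csv: two rows of three values
CSV_TEXT = "0,128,255\n1,2,3\n"
csvreader = csv.reader(io.StringIO(CSV_TEXT))

# A generator expression is already a generator object: iterable, but not callable
gen_obj = (row for row in csvreader)

# The lambda from the answer: calling it builds and returns a generator
gen_fn = lambda: (row for row in csvreader)

print(callable(gen_obj))         # False
print(callable(gen_fn))          # True
print(iter(gen_obj) is gen_obj)  # True: a generator is its own iterator
print(gen_fn() is gen_fn())      # False: each call returns a new generator object
```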

    Alternatively, try this, which is closer to how it is done in the docs:

    import tensorflow as tf
    import csv
    
    
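    def read_csv(file_name="test.csv"):
        # Generator function: each call opens the file and returns a new generator
        with open(file_name) as f:
            reader = csv.reader(f)
            for row in reader:
                yield row
    
    # Pass the function itself (read_csv), not the result of calling it (read_csv())
    ds = tf.data.Dataset.from_generator(read_csv, [tf.uint8]*28**2)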
    

    (If you need a different file name than whatever default you set, you can use functools.partial(read_csv, file_name="whatever.csv").)
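    To see why functools.partial fits here, the plain-Python sketch below (no TensorFlow needed) binds file_name and checks that the result is a zero-argument callable that returns a fresh generator on each call. The temporary file and its contents are made up for the demo; read_csv mirrors the function above, with newline="" added as the csv module documentation recommends.

```python
import csv
import functools
import os
import tempfile


def read_csv(file_name="test.csv"):
    # Mirrors the answer's generator function: opens the file anew on every call
    with open(file_name, newline="") as f:
        for row in csv.reader(f):
            yield row


with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical demo file standing in for "whatever.csv"
    path = os.path.join(tmp_dir, "whatever.csv")
    with open(path, "w", newline="") as f:
        csv.writer(f).writerows([[0, 128, 255], [1, 2, 3]])

    # partial binds file_name and leaves a callable that takes no arguments,
    # the kind of object from_generator expects as its first argument
    make_gen = functools.partial(read_csv, file_name=path)

    first_pass = list(make_gen())
    second_pass = list(make_gen())  # a new generator, so the file is read again

print(callable(make_gen))         # True
print(first_pass)                 # [['0', '128', '255'], ['1', '2', '3']]
print(second_pass == first_pass)  # True
```

    Note that csv.reader yields strings; the conversion to tf.uint8 is left to from_generator, as in the answer's code.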

    The difference is that read_csv returns a new generator object each time it is called, whereas what you constructed is already a generator object; it is equivalent to doing:

    gen = read_csv()  # a generator object, not a callable
    ds = tf.data.Dataset.from_generator(gen, [tf.uint8]*28**2)  # does not work: gen is not callable
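    A plain-Python sketch of why a callable is needed: as I understand from_generator, it calls the function again each time iteration over the dataset starts (for example, once per epoch), so it needs a factory rather than a single-use generator. run_epochs below is a hypothetical stand-in for that behaviour, not TensorFlow's code, and read_rows reads a made-up in-memory CSV.

```python
import csv
import io

# Hypothetical in-memory stand-in for test.csv
CSV_TEXT = "0,128,255\n1,2,3\n"


def read_rows():
    # Generator function: builds a fresh reader on every call
    for row in csv.reader(io.StringIO(CSV_TEXT)):
        yield row


def run_epochs(generator, num_epochs=2):
    # Hypothetical stand-in for iterating a dataset several times:
    # it rejects non-callables, then calls `generator` once per pass
    if not callable(generator):
        raise TypeError("generator must be callable")
    return [list(generator()) for _ in range(num_epochs)]


# Passing the function: every pass sees both rows
good = run_epochs(read_rows)

# Passing a generator object (like read_csv() above) fails the callable check
try:
    run_epochs(read_rows())
    rejected = False
except TypeError:
    rejected = True

# Wrapping an existing generator object in a lambda passes the check,
# but that single generator is exhausted after the first pass
gen = read_rows()
single_use = run_epochs(lambda: gen)

print(good)        # two passes, each containing both rows
print(rejected)    # True
print(single_use)  # the second pass is an empty list
```

    This is also why `lambda: read_csv()` would work where `lambda: gen` would not: the former creates a new generator on every call.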
    
