How to make a generator callable?

渐次进展 2021-01-04 10:12

I'm trying to create a dataset from a CSV file with 784-bit long rows. Here's my code:

import tensorflow as tf

f = open("test.csv", "r")
csvreader = c         
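A minimal sketch of the usual way around this for the CSV case: `from_generator` wants a callable that returns a fresh iterator, so the reading loop goes inside a zero-argument function instead of being passed as a generator object. The helper name `make_row_generator` and the assumption that every row holds 784 integer values are illustrative, not taken from the question.

```python
import csv
from typing import Callable, Iterator, List


def make_row_generator(path: str, row_length: int = 784) -> Callable[[], Iterator[List[int]]]:
    """Hypothetical helper: returns a zero-argument callable.

    Each call re-opens the CSV and yields its rows as lists of ints
    (assumes every cell is an integer and every row has row_length cells).
    """

    def gen() -> Iterator[List[int]]:
        # Re-open the file on every call so each pass starts from the top.
        with open(path, newline="") as f:
            for row in csv.reader(f):
                values = [int(v) for v in row]
                if len(values) != row_length:
                    raise ValueError(f"expected {row_length} values, got {len(values)}")
                yield values

    return gen
```

You would then hand the callable itself (not `gen()`) to TensorFlow, for example `tf.data.Dataset.from_generator(make_row_generator("test.csv"), output_signature=tf.TensorSpec(shape=(784,), dtype=tf.int32))`; `output_signature` is available from TF 2.4 on, and `tf.int32` is only a guess at the data type.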


        
3 Answers
  •  被撕碎了的回忆
    2021-01-04 10:27

    Yuck, two years later... But hey! Another solution! :D

    This might not be the cleanest answer, but for more complicated generators you can use a decorator. For example, I made a generator that yields a pair of dictionaries:

    >>> train,val = dataloader("path/to/dataset")
    >>> x,y = next(train)
    >>> print(x)
    {"data": [...], "filename": "image.png"}
    
    >>> print(y)
    {"category": "Dog", "category_id": 1, "background": "park"}
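To follow along without the real dataset, here is a hypothetical stand-in for `dataloader` that returns a training and a validation generator, each yielding `(x, y)` pairs of dictionaries shaped like the output above. The samples, file names and the 80/20 split are made up for illustration.

```python
from typing import Dict, Iterator, List, Tuple

Pair = Tuple[Dict[str, object], Dict[str, object]]


def dataloader(path: str) -> Tuple[Iterator[Pair], Iterator[Pair]]:
    """Hypothetical stand-in: ignores `path` and fabricates five samples."""
    samples: List[Pair] = [
        (
            {"data": [float(i)] * 4, "filename": f"image_{i}.png"},
            {"category": "Dog", "category_id": 1, "background": "park"},
        )
        for i in range(5)
    ]
    split = int(len(samples) * 0.8)  # made-up 80/20 train/validation split

    def iterate(items: List[Pair]) -> Iterator[Pair]:
        for x, y in items:
            yield x, y

    # These are generator *objects*, just like the ones in the answer.
    return iterate(samples[:split]), iterate(samples[split:])
```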
    

    When I tried passing it to from_generator, I got this error:

    >>> ds_tf = tf.data.Dataset.from_generator(
        iter(train),
        ({"data":tf.float32, "filename":tf.string},
        {"category":tf.string, "category_id":tf.int32, "background":tf.string})
        )
    TypeError: `generator` must be callable.
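The error comes from the difference between a generator function and the generator object it returns: `from_generator` checks that its first argument is callable, and a generator object (or `iter()` of one) is only an iterator. A small pure-Python check, using a hypothetical `numbers()` generator:

```python
def numbers():
    # A generator function: each call creates a new generator object.
    yield 1
    yield 2


gen_obj = numbers()         # a generator object, i.e. an iterator
same_obj = iter(gen_obj)    # iter() on a generator returns the generator itself

print(callable(numbers))    # True: calling it produces a fresh iterator
print(callable(gen_obj))    # False: this is what trips the callable check
print(same_obj is gen_obj)  # True: wrapping in iter() changes nothing
```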
    

    Then I wrote a decorating function that wraps the generator in a callable:

    >>> def make_gen_callable(_gen):
            def gen():
                for x, y in _gen:
                    yield x, y
            return gen
    >>> train_ = make_gen_callable(train)
    
    >>> train_ds = tf.data.Dataset.from_generator(
        train_,
        ({"data":tf.float32, "filename":tf.string},
        {"category":tf.string, "category_id":tf.int32, "background":tf.string})
        )
    
    >>> for x,y in train_ds:
            break
    
    >>> print(x)
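    {'data': <tf.Tensor: shape=..., dtype=float32, numpy=...>,
     'filename': <tf.Tensor: shape=..., dtype=string, numpy=...>
    }
    
    >>> print(y)
    {'category': <tf.Tensor: shape=..., dtype=string, numpy=...>,
     'category_id': <tf.Tensor: shape=..., dtype=int32, numpy=...>,
     'background': <tf.Tensor: shape=..., dtype=string, numpy=...>
    }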
    

    Note, though, that to iterate over train_ yourself, you now have to call it:

    >>> for x,y in train_():
            do_stuff(x,y)
            ...
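One caveat worth checking with this wrapper: `make_gen_callable` closes over a single generator object, so only the first call yields anything and later calls find it exhausted. As far as I understand, `from_generator` calls the callable again every time the dataset is iterated (for example once per epoch), so a second pass over `train_ds` would come back empty. Passing a function that builds a fresh generator on each call avoids this. The `pairs()` generator and the `make_gen_factory` name below are hypothetical.

```python
from typing import Callable, Dict, Iterator, Tuple

Pair = Tuple[Dict[str, float], Dict[str, int]]


def pairs() -> Iterator[Pair]:
    # Hypothetical stand-in for a generator yielding (x, y) dictionaries.
    yield {"data": 0.0}, {"category_id": 1}
    yield {"data": 1.0}, {"category_id": 2}


def make_gen_callable(_gen: Iterator[Pair]) -> Callable[[], Iterator[Pair]]:
    # The answer's wrapper: every call iterates the SAME generator object.
    def gen() -> Iterator[Pair]:
        for x, y in _gen:
            yield x, y
    return gen


def make_gen_factory(gen_fn: Callable[..., Iterator[Pair]], *args, **kwargs) -> Callable[[], Iterator[Pair]]:
    # Hypothetical alternative: every call creates a brand-new generator.
    def gen() -> Iterator[Pair]:
        return gen_fn(*args, **kwargs)
    return gen


once = make_gen_callable(pairs())
print(len(list(once())), len(list(once())))    # 2 0: second pass is empty

fresh = make_gen_factory(pairs)
print(len(list(fresh())), len(list(fresh())))  # 2 2: each call starts over
```

With the code in this answer, the equivalent would be passing something like `lambda: dataloader("path/to/dataset")[0]` to from_generator, at the cost of calling dataloader on every pass; `functools.partial` works the same way when the generator function takes arguments.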
    
