Tensorflow: create minibatch from numpy array > 2 GB

清酒与你 2021-01-21 10:01

I am trying to feed minibatches of numpy arrays to my model, but I'm stuck with batching. Using `tf.train.shuffle_batch` raises an error because the `images` array is too large (over 2 GB) to be embedded in the graph.

1 Answer
    傲寒
    2021-01-21 10:23

    You are using the initializable iterator of `tf.data` to feed data to your model. This means that you can parametrize the dataset in terms of placeholders and then call the iterator's initializer op to prepare it for use.

    When you use the initializable iterator, or any other `tf.data` iterator, to feed inputs to your model, you should not also pass the data through the `feed_dict` argument of `sess.run`. Instead, define your model in terms of the outputs of `iterator.get_next()` and omit the data from `feed_dict` in your training calls.

    Something along these lines:

    iterator = dataset.make_initializable_iterator()
    image_batch, label_batch = iterator.get_next()

    # use the get_next() outputs to define the model
    model = Model(config, image_batch, label_batch)

    # the large numpy arrays are fed once, while initializing the iterator
    sess.run(iterator.initializer,
             feed_dict={images_placeholder: images,
                        labels_placeholder: labels})

    for step in range(steps):
        # the iterator feeds image and label batches in the background
        sess.run(model.optimize)
    

    The iterator will feed data to your model in the background; additional feeding via `feed_dict` is not necessary.
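    To make the snippet above concrete, here is a minimal end-to-end sketch of the placeholder-parametrized dataset, assuming TensorFlow 1.x-style graph execution (`tf.compat.v1` on TF 2). The array shapes and names (`images`, `labels`, the placeholders) are illustrative small stand-ins for the > 2 GB arrays; `Model` and the training loop are omitted.

    ```python
    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    # Small stand-ins for the real (> 2 GB) arrays.
    images = np.random.rand(100, 28, 28).astype(np.float32)
    labels = np.random.randint(0, 10, size=100).astype(np.int64)

    # Placeholders keep the arrays out of the graph definition, which is
    # what avoids the 2 GB GraphDef limit: the data is fed once, at
    # iterator initialization, instead of being baked in as a constant.
    images_placeholder = tf.placeholder(images.dtype, images.shape)
    labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

    dataset = tf.data.Dataset.from_tensor_slices(
        (images_placeholder, labels_placeholder))
    dataset = dataset.shuffle(buffer_size=100).batch(32).repeat()

    iterator = tf.data.make_initializable_iterator(dataset)
    image_batch, label_batch = iterator.get_next()

    with tf.Session() as sess:
        sess.run(iterator.initializer,
                 feed_dict={images_placeholder: images,
                            labels_placeholder: labels})
        imgs, lbls = sess.run([image_batch, label_batch])
        print(imgs.shape)  # (32, 28, 28)
    ```

    The shuffling that `tf.train.shuffle_batch` was doing is handled here by `dataset.shuffle(...).batch(...)`, entirely inside the `tf.data` pipeline.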
