Tensorflow: tf.data.Dataset, Cannot batch tensors with different shapes in component 0

暗喜 2021-02-15 17:57

I have the following error in my input pipeline:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shape

  •  挽巷 (original poster), 2021-02-15 18:02

    First case: we want the output to have a fixed batch size

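    In this case, the generator yields values of shape [None, 48, 48, 3], where the first dimension can vary from one element to the next. We want to batch them so that the output has shape [batch_size, 48, 48, 3]. If we apply tf.data.Dataset.batch directly, we get the error above, because batch stacks elements along a new axis and therefore requires every element to have the same shape. So we need to unbatch first: split each element into individual [48, 48, 3] items, then batch those.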
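To see why batching fails while unbatching and then re-batching works, here is a NumPy analogy: np.stack stands in for batch (stacking along a new axis), and np.concatenate followed by slicing stands in for unbatch followed by batch. This is an illustration of the shapes involved, not TensorFlow's implementation; the zero-filled arrays are hypothetical stand-ins for two generator outputs.

```python
import numpy as np

# Hypothetical stand-ins for two generator outputs: 2 and 3 images of 48x48x3
a = np.zeros((2, 48, 48, 3), dtype=np.float32)
b = np.zeros((3, 48, 48, 3), dtype=np.float32)

# Stacking along a new axis (what batch does) fails: the shapes differ in dim 0
try:
    np.stack([a, b])
    stack_failed = False
except ValueError:
    stack_failed = True

# Concatenating along the existing first axis (the effect of unbatching) works
images = np.concatenate([a, b], axis=0)  # shape (5, 48, 48, 3)

# Regrouping into fixed-size batches of 2 (the effect of batching afterwards);
# the last batch is smaller because 5 is not a multiple of 2
batch_size = 2
batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
batch_shapes = [batch.shape for batch in batches]
print(stack_failed, batch_shapes)
```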

    To do that, we can apply tf.contrib.data.unbatch before batching (tf.contrib exists only in TensorFlow 1.x; in TensorFlow 2.x, use the Dataset.unbatch() method instead):

    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.batch(batch_size)
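Assuming TensorFlow 2.x, where tf.contrib no longer exists, a minimal sketch of the same unbatch-then-batch fix uses the Dataset.unbatch() method. image_gen is a hypothetical generator that yields zero-filled arrays of shape [n, 48, 48, 3] for n = 1..4.

```python
import numpy as np
import tensorflow as tf

def image_gen():
    # Hypothetical generator: yields 1, 2, 3 and 4 images of shape 48x48x3 at a time
    for n in range(1, 5):
        yield np.zeros((n, 48, 48, 3), dtype=np.float32)

dataset = tf.data.Dataset.from_generator(
    image_gen, tf.float32, tf.TensorShape([None, 48, 48, 3]))
dataset = dataset.unbatch()                      # each element: [48, 48, 3]
dataset = dataset.batch(4, drop_remainder=True)  # each element: [4, 48, 48, 3]

# TF 2.x iterates datasets eagerly, so no Session or iterator is needed
batch_shapes = [tuple(batch.shape) for batch in dataset]
print(batch_shapes)
```

The generator produces 10 images in total, so with drop_remainder=True the 2 leftover images are discarded and every batch has exactly 4 images; without it, the last batch would hold only 2.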
    

    Here is a full example where the generator yields [1], [2, 2], [3, 3, 3] and [4, 4, 4, 4].

    We can't batch these output values directly, so we unbatch and then batch them:

    import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.Session)

    def gen():
        for i in range(1, 5):
            yield [i] * i

    # Create dataset from generator
    # The output shape is variable: (None,)
    dataset = tf.data.Dataset.from_generator(gen, tf.int64, tf.TensorShape([None]))

    # Elements have different lengths, so batching them directly fails:
    # unbatch into scalars first, then regroup them into batches of 2
    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.batch(2)

    # Create iterator from dataset
    iterator = dataset.make_one_shot_iterator()
    x = iterator.get_next()  # shape (None,)

    with tf.Session() as sess:
        for i in range(5):
            print(sess.run(x))
    

    This will print the following output:

    [1 2]
    [2 3]
    [3 3]
    [4 4]
    [4 4]
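These batches can be checked by hand: unbatch flattens the stream into 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, and batch(2) regroups consecutive pairs, with any remainder forming a smaller final batch. A plain-Python model of that behavior (an illustration of the semantics, not how tf.data is implemented; unbatch_then_batch is a hypothetical helper):

```python
from itertools import chain, islice

def gen():
    for i in range(1, 5):
        yield [i] * i

def unbatch_then_batch(elements, batch_size):
    """Flatten variable-length elements, then regroup them into fixed-size batches."""
    flat = chain.from_iterable(elements)  # 1, 2, 2, 3, 3, 3, 4, 4, 4, 4
    batches = []
    while True:
        batch = list(islice(flat, batch_size))
        if not batch:
            return batches
        batches.append(batch)

print(unbatch_then_batch(gen(), 2))
# [[1, 2], [2, 3], [3, 3], [4, 4], [4, 4]]
```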
    

    Second case: we want to concatenate variable sized batches

    Update (03/30/2018): I removed the previous answer, which used sharding and slowed performance down considerably (see comments).

    In this case, we want to concatenate a fixed number of batches, but these batches have variable sizes. For instance, the dataset yields [1] and [2, 2], and we want [1, 2, 2] as the output.

    A quick way to solve this is to create a new generator that wraps the original one and yields the concatenated batches. (Thanks to Guillaume for the idea.)
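The wrapping idea itself does not depend on TensorFlow: the new generator pulls batch_size consecutive items from the original one and concatenates them. A compact variant using itertools.islice, which can be tested on its own (concat_batches is a hypothetical name, not from the answer). Like get_batch_gen in the full example, it returns a generator function, so the result can be passed to tf.data.Dataset.from_generator.

```python
from itertools import islice

import numpy as np

def gen():
    for i in range(1, 5):
        yield [i] * i

def concat_batches(gen_fn, batch_size=2):
    """Return a generator function that concatenates batch_size consecutive items of gen_fn()."""
    def batch_gen():
        it = iter(gen_fn())
        while True:
            chunk = list(islice(it, batch_size))  # up to batch_size items
            if not chunk:
                return
            yield np.concatenate(chunk, axis=0)
    return batch_gen

outputs = [x.tolist() for x in concat_batches(gen, 2)()]
print(outputs)  # [[1, 2, 2], [3, 3, 3, 4, 4, 4, 4]]
```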


    Here is a full example where the generator yields [1], [2, 2], [3, 3, 3] and [4, 4, 4, 4].

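    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (make_one_shot_iterator, tf.Session)

    def gen():
        for i in range(1, 5):
            yield [i] * i

    def get_batch_gen(gen, batch_size=2):
        # Wrap gen so that each output concatenates batch_size consecutive items
        def batch_gen():
            buff = []
            for i, x in enumerate(gen()):
                # Every batch_size items, emit the buffered items as one array
                if i % batch_size == 0 and buff:
                    yield np.concatenate(buff, axis=0)
                    buff = []
                buff += [x]

            # Emit whatever is left over (possibly fewer than batch_size items)
            if buff:
                yield np.concatenate(buff, axis=0)

        return batch_gen

    # Create dataset from the wrapped generator
    batch_size = 2
    dataset = tf.data.Dataset.from_generator(get_batch_gen(gen, batch_size),
                                             tf.int64, tf.TensorShape([None]))

    # Create iterator from dataset
    iterator = dataset.make_one_shot_iterator()
    x = iterator.get_next()  # shape (None,)

    with tf.Session() as sess:
        for i in range(2):
            print(sess.run(x))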
    

    This will print the following output:

    [1 2 2]
    [3 3 3 4 4 4 4]
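A possible alternative that avoids the Python wrapper, assuming TensorFlow 2.x with tf.data.experimental.dense_to_ragged_batch available (it appeared in the 2.x series; check your version's documentation): batch the variable-length elements into a tf.RaggedTensor, then take its flat_values, which concatenates the rows of each batch. This is a swapped-in technique, not part of the answer.

```python
import tensorflow as tf

def gen():
    for i in range(1, 5):
        yield [i] * i

dataset = tf.data.Dataset.from_generator(gen, tf.int64, tf.TensorShape([None]))
# Group 2 variable-length elements into a RaggedTensor of shape [2, None]
dataset = dataset.apply(tf.data.experimental.dense_to_ragged_batch(2))
# Concatenate the rows of each ragged batch into one 1-D tensor
dataset = dataset.map(lambda rt: rt.flat_values)

outputs = [x.numpy().tolist() for x in dataset]
print(outputs)  # [[1, 2, 2], [3, 3, 3, 4, 4, 4, 4]]
```

The concatenation then happens inside the tf.data pipeline rather than in Python, which may matter if the generator is not the bottleneck.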
    
