Keras ImageDataGenerator for multiple inputs and an image-based target output

小蘑菇 2021-01-16 04:14

I have a model that takes two images as inputs and produces a single image as the target output.

All of my training image data is in the following subfolders:

2 Answers
  •  栀梦 (original poster)
    2021-01-16 05:14

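    One option is to join three ImageDataGenerator iterators into one, using class_mode=None (so they return only images, no targets) and shuffle=False (important: with shuffling off, every iterator returns its files in sorted filename order, so position k refers to the same sample in all three). Use the same batch_size (and target_size) for each, keep each input and the targets in separate directories, and make sure every directory holds exactly the same number of images.

    Note that even with class_mode=None, flow_from_directory only picks up images inside subdirectories of the folder you pass, so place the files in a subfolder (for example input1_dir/images/) and give corresponding images names that sort into the same order in every folder.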
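    Because the three iterators are paired only by position, it can help to check the folders before training. The sketch below is plain Python (no Keras): it walks each folder's subdirectories in sorted order, roughly the way a non-shuffled flow_from_directory does, and confirms the folders hold the same number of images with matching names. The helper names sorted_image_names and check_aligned_dirs, and the convention that paired images share a filename, are assumptions for illustration.

```python
import os

# Extensions flow_from_directory accepts by default (matched case-insensitively).
IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".ppm", ".tif", ".tiff")


def sorted_image_names(directory):
    """Image filenames found in the subdirectories of `directory`,
    in sorted order (files placed directly in `directory` are skipped)."""
    names = []
    for sub in sorted(os.listdir(directory)):
        sub_path = os.path.join(directory, sub)
        if not os.path.isdir(sub_path):
            continue  # top-level files are not picked up by flow_from_directory
        for root, _, files in sorted(os.walk(sub_path), key=lambda t: t[0]):
            for fname in sorted(files):
                if fname.lower().endswith(IMAGE_EXTS):
                    names.append(fname)
    return names


def check_aligned_dirs(*directories):
    """Raise ValueError if the folders differ in image count or filename order;
    otherwise return the shared image count."""
    listings = [sorted_image_names(d) for d in directories]
    counts = [len(names) for names in listings]
    if len(set(counts)) != 1:
        raise ValueError(f"image counts differ: {dict(zip(directories, counts))}")
    for i, group in enumerate(zip(*listings)):
        if len(set(group)) != 1:
            raise ValueError(f"position {i} pairs different files: {group}")
    return counts[0]
```

    With the folder names from this answer, you would call check_aligned_dirs('input1_dir', 'input2_dir', 'target_dir') before creating the generators.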

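    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    # Use identical settings for all three; with random augmentation, the shared
    # seed below keeps inputs and target transformed the same way.
    idg1 = ImageDataGenerator(rescale=1./255)  # ...choose params...
    idg2 = ImageDataGenerator(rescale=1./255)
    idg3 = ImageDataGenerator(rescale=1./255)

    gen1 = idg1.flow_from_directory('input1_dir',
                                    target_size=(256, 256),
                                    batch_size=32,
                                    seed=42,
                                    shuffle=False,
                                    class_mode=None)
    gen2 = idg2.flow_from_directory('input2_dir',
                                    target_size=(256, 256),
                                    batch_size=32,
                                    seed=42,
                                    shuffle=False,
                                    class_mode=None)
    gen3 = idg3.flow_from_directory('target_dir',
                                    target_size=(256, 256),
                                    batch_size=32,
                                    seed=42,
                                    shuffle=False,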
                                    class_mode=None)
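    The same batch_size matters because a non-shuffled Keras iterator serves batch i as the slice [i * batch_size, (i + 1) * batch_size) of its sorted file list (the last batch may be shorter), so its length is ceil(n / batch_size). The plain-Python sketch below reproduces that arithmetic; the helper name batch_bounds is hypothetical. Equal image counts and equal batch sizes give identical slices, which is what lets gen1[i], gen2[i] and gen3[i] line up.

```python
def batch_bounds(n_images, batch_size):
    """(start, stop) index pairs covered by each batch of a non-shuffled iterator."""
    n_batches = (n_images + batch_size - 1) // batch_size  # ceil division, like len(gen)
    return [(i * batch_size, min((i + 1) * batch_size, n_images))
            for i in range(n_batches)]


# Mismatched batch sizes: batch 1 of one iterator covers different samples than batch 1 of the other.
print(batch_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
print(batch_bounds(10, 3))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```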
    

    Then join them with a custom Sequence:

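    import tensorflow as tf

    class JoinedGen(tf.keras.utils.Sequence):
        """Combines two input iterators and one target iterator batch by batch."""
        def __init__(self, input_gen1, input_gen2, target_gen):
            super().__init__()  # needed by newer Keras versions, harmless in older ones
            self.gen1 = input_gen1
            self.gen2 = input_gen2
            self.gen3 = target_gen

            # All three must yield the same number of batches, or the pairs drift apart.
            assert len(input_gen1) == len(input_gen2) == len(target_gen)

        def __len__(self):
            return len(self.gen1)

        def __getitem__(self, i):
            # Batch i of each iterator covers the same positions in its sorted file list.
            x1 = self.gen1[i]
            x2 = self.gen2[i]
            y = self.gen3[i]

            return [x1, x2], y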
    
        def on_epoch_end(self):
            self.gen1.on_epoch_end()
            self.gen2.on_epoch_end()
            self.gen3.on_epoch_end()
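    JoinedGen only relies on each wrapped iterator supporting len() and integer indexing, so the pairing logic can be tried without TensorFlow or any images. In this sketch, FakeIterator and JoinedBatches are hypothetical stand-ins: the fake returns batches of labels instead of image arrays, which makes it easy to see that batch i of both inputs and of the target come from the same sample positions.

```python
class FakeIterator:
    """Stand-in for a non-shuffled DirectoryIterator: returns batches of labels."""
    def __init__(self, names, batch_size):
        self.names = names
        self.batch_size = batch_size

    def __len__(self):
        return (len(self.names) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, i):
        if i >= len(self):
            raise IndexError(i)  # out-of-range batch index
        return self.names[i * self.batch_size:(i + 1) * self.batch_size]


class JoinedBatches:
    """Same joining logic as JoinedGen, without the tf.keras.utils.Sequence base."""
    def __init__(self, input_gen1, input_gen2, target_gen):
        assert len(input_gen1) == len(input_gen2) == len(target_gen)
        self.gens = (input_gen1, input_gen2, target_gen)

    def __len__(self):
        return len(self.gens[0])

    def __getitem__(self, i):
        x1, x2, y = (g[i] for g in self.gens)
        return [x1, x2], y


joined = JoinedBatches(
    FakeIterator(["in1/a", "in1/b", "in1/c"], batch_size=2),
    FakeIterator(["in2/a", "in2/b", "in2/c"], batch_size=2),
    FakeIterator(["tgt/a", "tgt/b", "tgt/c"], batch_size=2),
)
for i in range(len(joined)):
    print(joined[i])
# ([['in1/a', 'in1/b'], ['in2/a', 'in2/b']], ['tgt/a', 'tgt/b'])
# ([['in1/c'], ['in2/c']], ['tgt/c'])
```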
    

    Train with this generator:

    my_gen = JoinedGen(gen1, gen2, gen3)
    # fit_generator is deprecated (since TF 2.1); fit accepts a Sequence directly.
    model.fit(my_gen, ...)
    

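    Alternatively, write the custom generator as a plain Python generator function (using yield) instead of a Sequence class; it follows the same structure as the class above.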
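    A minimal sketch of that alternative, assuming the same three non-shuffled iterators: Keras directory iterators also support next(), so a generator function can advance them in lockstep and yield ([x1, x2], y) tuples. The function name joined_generator is hypothetical, and the demo feeds it repeating plain-Python stand-ins instead of image iterators.

```python
import itertools


def joined_generator(input_gen1, input_gen2, target_gen):
    """Yield ([x1, x2], y) forever, advancing the three iterators together."""
    while True:
        x1 = next(input_gen1)
        x2 = next(input_gen2)
        y = next(target_gen)
        yield [x1, x2], y


# Demo with endlessly repeating stand-in "batches" instead of image arrays.
demo = joined_generator(itertools.cycle(["a1", "b1"]),
                        itertools.cycle(["a2", "b2"]),
                        itertools.cycle(["aT", "bT"]))
for _ in range(3):
    print(next(demo))
# (['a1', 'a2'], 'aT')
# (['b1', 'b2'], 'bT')
# (['a1', 'a2'], 'aT')
```

    Because the Keras iterators loop forever when driven with next(), this generator never ends on its own, so fit needs steps_per_epoch, e.g. model.fit(joined_generator(gen1, gen2, gen3), steps_per_epoch=len(gen1), ...). The Sequence version is usually the safer choice, since Keras can index it directly and call on_epoch_end for you.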
