Keras ImageDataGenerator for multiple inputs and image-based target output

小蘑菇 2021-01-16 04:14

I have a model that takes two images as inputs and generates a single image as the target output.

All of my training image data is in the following subfolders:

2 Answers
  •  生来不讨喜
    2021-01-16 05:03

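    The custom class JoinedGen shown in Daniel Möller's post works well if you do not want (or need) to shuffle the training examples. More often than not, however, reshuffling at the end of each epoch is desirable, because the model then sees the samples in a different order and in different batch combinations every epoch. Fortunately, this is easy to add. First, all three flow_from_directory calls should use shuffle=True (important). They should also use the same seed, so that the three iterators draw identical random numbers for shuffling and augmentation:

    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    # Example parameters only; choose your own. Geometric augmentations stay
    # aligned between inputs and target only if all three use the same arguments.
    data_gen_args = dict(rotation_range=10, horizontal_flip=True)
    seed = 1  # any fixed integer, shared by all three iterators

    idg1 = ImageDataGenerator(**data_gen_args)
    idg2 = ImageDataGenerator(**data_gen_args)
    idg3 = ImageDataGenerator(**data_gen_args)

    # class_mode=None returns images only. Each folder still needs one
    # subfolder that holds its images, e.g. input1_dir/images/*.png.
    gen1 = idg1.flow_from_directory('input1_dir', shuffle=True,
                                    seed=seed, class_mode=None)
    gen2 = idg2.flow_from_directory('input2_dir', shuffle=True,
                                    seed=seed, class_mode=None)
    gen3 = idg3.flow_from_directory('target_dir', shuffle=True,
                                    seed=seed, class_mode=None)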
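
    A fixed seed shared by the three iterators keeps them drawing the same random numbers. The sketch below models this with plain NumPy. It assumes that a seeded iterator reseeds NumPy's global RNG with seed + total_batches_seen before each batch, which is how keras-preprocessing's Iterator behaves. batch_randomness is a hypothetical helper, and its rotation and flip draws only approximate what ImageDataGenerator.get_random_transform does.

```python
import numpy as np


def batch_randomness(seed, batches_seen, n_images, rotation_range=10):
    """Hypothetical model of one batch of a seeded iterator: reseed NumPy's
    global RNG, then draw a sample order and per-image augmentation values."""
    if seed is not None:
        np.random.seed(seed + batches_seen)
    order = np.random.permutation(n_images)
    angles = np.random.uniform(-rotation_range, rotation_range, size=n_images)
    flips = np.random.random(size=n_images) < 0.5
    return order, angles, flips


def identical(draws_a, draws_b):
    """True if two (order, angles, flips) tuples match element for element."""
    return all(np.array_equal(x, y) for x, y in zip(draws_a, draws_b))


# Three iterators sharing seed=1, all building their first batch (batches_seen=0).
shared = [batch_randomness(1, 0, 8) for _ in range(3)]
print(identical(shared[0], shared[1]) and identical(shared[0], shared[2]))  # True

# No seed: each call continues from the current global RNG state instead.
np.random.seed(0)
unseeded = [batch_randomness(None, 0, 8) for _ in range(3)]
print(identical(unseeded[0], unseeded[1]))  # False
```

    The real iterators draw augmentation parameters per image inside each batch. The principle should be the same, though: identical seeds plus identical ImageDataGenerator arguments give identical draws. Shuffling alone does not keep the iterators aligned across epochs, because on_epoch_end() reshuffles without reseeding.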
    

    Even so, the three iterators fall out of sync as soon as an epoch ends. Each one reshuffles its own index_array independently in on_epoch_end(), so image k of one folder is no longer paired with image k of the others. To keep them aligned, add two lines at the end of JoinedGen.on_epoch_end() that give the last two iterators the same index array as the first one:

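    import tensorflow as tf

    class JoinedGen(tf.keras.utils.Sequence):
        """Yields ([x1, x2], y) batches from two input iterators and one target iterator."""

        def __init__(self, input_gen1, input_gen2, target_gen):
            super().__init__()
            self.gen1 = input_gen1
            self.gen2 = input_gen2
            self.gen3 = target_gen

            # Equal batch counts are not enough: every folder must hold the same
            # number of images and every iterator must use the same batch size.
            # flow_from_directory sorts file names, so matching images should
            # sort identically in all three folders (e.g. share their names).
            assert input_gen1.n == input_gen2.n == target_gen.n
            assert input_gen1.batch_size == input_gen2.batch_size == target_gen.batch_size

            # Draw one shared order now, so the first epoch is aligned as well.
            self.on_epoch_end()

        def __len__(self):
            return len(self.gen1)

        def __getitem__(self, i):
            # With class_mode=None each iterator returns only the image batch.
            x1 = self.gen1[i]
            x2 = self.gen2[i]
            y = self.gen3[i]

            return [x1, x2], y

        def on_epoch_end(self):
            self.gen1.on_epoch_end()
            self.gen2.on_epoch_end()
            self.gen3.on_epoch_end()
            # Overwrite the independent reshuffles of the last two iterators
            # with the order drawn by the first one.
            self.gen2.index_array = self.gen1.index_array
            self.gen3.index_array = self.gen1.index_array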
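
    You can check the effect of sharing index_array without Keras or any image files. This is a toy model under stated assumptions. ToyIterator, share_order and aligned are hypothetical helpers. ToyIterator imitates only the part of a Keras directory iterator that matters here: an index_array that selects each batch and is redrawn in on_epoch_end().

```python
import numpy as np


class ToyIterator:
    """Hypothetical stand-in for a Keras directory iterator: it keeps only an
    index_array and redraws it in on_epoch_end(), like the real iterator."""

    def __init__(self, items, batch_size):
        self.items = list(items)
        self.batch_size = batch_size
        self.index_array = np.random.permutation(len(self.items))

    def __len__(self):
        return -(-len(self.items) // self.batch_size)  # batches, rounded up

    def __getitem__(self, i):
        idx = self.index_array[i * self.batch_size:(i + 1) * self.batch_size]
        return [self.items[j] for j in idx]

    def on_epoch_end(self):
        self.index_array = np.random.permutation(len(self.items))


def share_order(first, *others):
    """Same idea as JoinedGen.on_epoch_end: reshuffle the first iterator and
    hand its index array to the others (their own reshuffle would be overwritten)."""
    first.on_epoch_end()
    for gen in others:
        gen.index_array = first.index_array


def aligned(*gens):
    """True when every batch holds the same base file names in every iterator."""
    return all(
        [name.rsplit("_", 1)[0] for name in gens[0][i]]
        == [name.rsplit("_", 1)[0] for name in gen[i]]
        for gen in gens[1:]
        for i in range(len(gens[0]))
    )


base = [f"img{k:02d}" for k in range(10)]
a = ToyIterator([name + "_in1" for name in base], batch_size=4)
b = ToyIterator([name + "_in2" for name in base], batch_size=4)
t = ToyIterator([name + "_target" for name in base], batch_size=4)

np.random.seed(0)
for gen in (a, b, t):
    gen.on_epoch_end()  # independent reshuffles: the triples get mixed up
print(aligned(a, b, t))  # False

for epoch in range(3):
    share_order(a, b, t)
    print(aligned(a, b, t))  # True
```

    Independent reshuffles break the pairing between the folders. Handing one index array to all iterators restores it every epoch, and that is what the two extra lines in on_epoch_end() do.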
    
