keras ImageDataGenerator interpolates binary mask

你离开我真会死。 submitted on 2021-01-27 06:34:13

Question


I am training a neural network to predict a binary mask on mouse brain images. For this I am augmenting my data with the ImageDataGenerator from keras.

However, I have realized that the data generator interpolates the data when it applies spatial transformations.

This is fine for the image, but I certainly do not want my mask to contain non-binary values.

Is there any way to choose something like a nearest neighbor interpolation when applying the transformations? I have found no such option in the keras documentation.

(Figure not included: on the left, the original binary mask; on the right, the augmented, interpolated mask.)

Code for the images:

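import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import preprocessing as kp  # assumed import; the original alias kp is not shown

# image: a 4-D array of shape (1, height, width, channels)
data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,  # int: shift in pixels
                     height_shift_range=30,
                     shear_range=5,         # shear angle in degrees
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest')   # only affects points outside the input boundaries
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, seed=1)

# Left: original mask; right: one augmented sample.
plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.squeeze(image))
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(np.squeeze(next(image_generator)[0]))
plt.axis('off')
plt.savefig('vis/keras_example')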
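The effect described in the question can be reproduced outside Keras to see where the non-binary values come from. As far as I know, ImageDataGenerator applies its affine transforms through SciPy's ndimage module with linear (order-1) interpolation; the sketch below calls scipy.ndimage.rotate directly on a made-up square mask and compares order=1 (bilinear) with order=0 (nearest-neighbor). Note that fill_mode='nearest' in the question's arguments only controls how points outside the input boundaries are filled; it does not change how values between pixels are interpolated.

```python
import numpy as np
from scipy import ndimage

# Toy 32x32 binary mask (made-up data): a filled square in the middle.
mask = np.zeros((32, 32), dtype=np.float32)
mask[8:24, 8:24] = 1.0

# The same 30-degree rotation with two interpolation orders.
# reshape=False keeps the output at 32x32; points mapped from outside the
# input are filled with cval=0 (mode='constant' is the default).
linear = ndimage.rotate(mask, 30, reshape=False, order=1)   # bilinear
nearest = ndimage.rotate(mask, 30, reshape=False, order=0)  # nearest-neighbor


def count_non_binary(a):
    """Number of pixels whose value is neither exactly 0 nor exactly 1."""
    return int(np.count_nonzero((a != 0) & (a != 1)))


print("non-binary pixels with order=1:", count_non_binary(linear))
print("non-binary pixels with order=0:", count_non_binary(nearest))
```

With order=0 every output pixel is copied from some input pixel, so the mask stays binary; with order=1 the pixels along the rotated edges are weighted averages of 0 and 1.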

Answer 1:


I had the same problem with my own binary image data. There are several ways to approach this issue.

Simple answer: I solved it by manually converting the output of ImageDataGenerator to binary. If you iterate over the generator manually (with the next() method or a for loop), you can simply use NumPy's where function to convert non-binary values to binary:

import numpy as np

batch = image_generator.next()
binary_images = np.where(batch>0, 1, 0)  ## or batch>0.5 or any other thresholds
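The choice of threshold in that line matters for masks: with > 0, every pixel that interpolation touched even slightly becomes 1, which grows the mask along its border, while a midpoint threshold keeps the boundary close to where it was. A small sketch, assuming masks in the [0, 1] range; binarize and the sample values are hypothetical.

```python
import numpy as np


def binarize(batch, threshold=0.5):
    """Map every value above `threshold` to 1 and everything else to 0."""
    return np.where(batch > threshold, 1, 0).astype(np.uint8)


# Made-up augmented mask values: 0/1 plus interpolated edge values.
batch = np.array([[0.0, 0.1, 0.5, 0.9, 1.0]], dtype=np.float32)

print(binarize(batch, threshold=0.0))  # [[0 1 1 1 1]]
print(binarize(batch))                 # [[0 0 0 1 1]]
```

If your masks are stored as 0/255 instead, the midpoint would be 127.5.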

Using the preprocessing_function argument in ImageDataGenerator

A better way is to use the preprocessing_function argument of ImageDataGenerator. As the documentation states, you can specify a custom preprocessing function that is executed after the data augmentation, so you can add it to your data_gen_args as follows:

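import numpy as np
from keras.preprocessing.image import ImageDataGenerator

data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest',
                     # Runs on each augmented sample: binarize it and keep its dtype.
                     preprocessing_function=lambda x: np.where(x > 0, 1, 0).astype(x.dtype))
mask_datagen = ImageDataGenerator(**data_gen_args)  # use these args for the mask generator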
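If the input images are augmented with the same data_gen_args (the usual way to keep image and mask transforms in sync), putting preprocessing_function in that shared dict would binarize the input images as well. A minimal sketch that keeps the binarizer on the mask side only; binarize_mask, base_args, image_args and mask_args are hypothetical names, the 0.5 threshold assumes masks in [0, 1], and the ImageDataGenerator lines are left as comments so the block runs without Keras.

```python
import numpy as np


def binarize_mask(x, threshold=0.5):
    """Candidate preprocessing_function for the mask generator only.

    ImageDataGenerator calls it on one augmented sample at a time (a rank-3
    array such as (height, width, 1)) and expects an array of the same shape.
    """
    return np.where(x > threshold, 1, 0).astype(x.dtype)


# Shared augmentation settings; only the mask arguments get the binarizer.
base_args = dict(rotation_range=90,
                 horizontal_flip=True,
                 vertical_flip=True,
                 fill_mode='nearest')
image_args = dict(base_args)
mask_args = dict(base_args, preprocessing_function=binarize_mask)
# image_datagen = ImageDataGenerator(**image_args)
# mask_datagen = ImageDataGenerator(**mask_args)

# Stand-in for one augmented mask sample with interpolated values.
sample = np.random.default_rng(0).random((64, 64, 1)).astype(np.float32)
out = binarize_mask(sample)
print(out.shape, out.dtype, np.unique(out))
```

Using a named function instead of a lambda also makes the threshold easy to change and to test on its own.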

Note: in my experience, preprocessing_function is executed before rescale, which can also be passed as an argument to ImageDataGenerator in your data_gen_args. You are not using rescale, but if you do, keep in mind that your function sees the values before rescaling, so its threshold must match that value range.
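Because preprocessing_function runs before rescale, a binarizing function sees values in the original range (for example 0 to 255), so its threshold and output values have to be chosen for that range. The sketch below mimics only that ordering with a hypothetical standardize_sketch helper; it is not the Keras implementation, and the 0/255 mask values are made up.

```python
import numpy as np


def standardize_sketch(x, preprocessing_function=None, rescale=None):
    """Hypothetical stand-in reproducing only the order described above:
    preprocessing_function first, then rescale. Not the Keras source."""
    if preprocessing_function is not None:
        x = preprocessing_function(x)
    if rescale:
        x = x * rescale
    return x


# Interpolated values from a mask stored as 0/255 (made-up numbers).
mask = np.array([0.0, 40.0, 127.0, 200.0, 255.0], dtype=np.float32)

# Threshold written for the [0, 1] range: it sees 0..255, so every nonzero
# pixel becomes 1, and rescale then shrinks it to 1/255.
too_early = standardize_sketch(
    mask, lambda x: np.where(x > 0.5, 1, 0).astype(x.dtype), rescale=1 / 255)

# Threshold and output written for the [0, 255] range: after rescale the
# mask is 0/1 as intended.
matched = standardize_sketch(
    mask, lambda x: np.where(x > 127.5, 255, 0).astype(x.dtype), rescale=1 / 255)

print(too_early)
print(matched)
```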

Create a custom generator

Another solution is to write a custom data generator and modify the output of ImageDataGenerator inside it. Then use this new generator to feed model.fit(). Something like this:

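import numpy as np
from tensorflow.keras.utils import Sequence

# kp, data_gen_args and image come from the question's code.
batch_size = 64
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, batch_size=batch_size, seed=1)

class MyImageDataGenerator(Sequence):
    """Wraps a Keras iterator and binarizes every batch it yields."""

    def __init__(self, generator, data_size, batch_size):
        super().__init__()
        self.generator = generator
        self.data_size = data_size
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(self.data_size / float(self.batch_size)))

    def __getitem__(self, idx):
        # idx is not used: the wrapped iterator already shuffles and augments.
        augmented_data = next(self.generator)
        binary_images = np.where(augmented_data > 0, 1, 0)
        return binary_images  # inputs only; add targets for supervised training

my_image_generator = MyImageDataGenerator(image_generator, data_size=len(image), batch_size=batch_size)
model.fit(my_image_generator, epochs=50)  # model: your compiled Keras model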

The data generator above is deliberately simple. If needed, you can extend it to also return your labels, multimodal data, etc.
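One way to add the labels for segmentation is the pattern from the Keras ImageDataGenerator documentation for transforming images and masks together: two generators created with the same seed, zipped together. The sketch below wraps that in a plain Python generator that also thresholds the masks. paired_batches is a hypothetical name, the 0.5 threshold assumes masks in [0, 1], and the Keras lines are shown as comments so the block runs without TensorFlow.

```python
import numpy as np


def paired_batches(image_iter, mask_iter, threshold=0.5):
    """Yield (images, binary_masks) pairs from two synchronized iterators.

    Intended for two ImageDataGenerator.flow() iterators created with the
    same seed, so that each image batch and mask batch get the same transform.
    """
    for images, masks in zip(image_iter, mask_iter):
        yield images, np.where(masks > threshold, 1, 0).astype(masks.dtype)


# Usage with Keras (not run here):
# image_iter = image_datagen.flow(images, batch_size=8, seed=1)
# mask_iter = mask_datagen.flow(masks, batch_size=8, seed=1)
# model.fit(paired_batches(image_iter, mask_iter),
#           steps_per_epoch=int(np.ceil(len(images) / 8)), epochs=50)

# Self-contained check with fake batches standing in for the iterators.
rng = np.random.default_rng(0)
fake_images = [rng.random((8, 16, 16, 1)).astype(np.float32) for _ in range(3)]
fake_masks = [rng.random((8, 16, 16, 1)).astype(np.float32) for _ in range(3)]

pairs = list(paired_batches(iter(fake_images), iter(fake_masks)))
print(len(pairs), pairs[0][0].shape, np.unique(pairs[0][1]))
```

Because flow() iterators never stop on their own, model.fit needs steps_per_epoch when given a plain generator like this one.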



Source: https://stackoverflow.com/questions/60521736/keras-imagedatagenerator-interpolates-binary-mask
