Question
I am training a neural network to predict a binary mask on mouse brain images. For this, I am augmenting my data with the ImageDataGenerator from Keras.
But I have noticed that the data generator interpolates the data when applying spatial transformations.
This is fine for the image, but I certainly do not want my mask to contain non-binary values.
Is there any way to choose something like nearest-neighbor interpolation when applying the transformations? I have found no such option in the Keras documentation.
[Figure: left, the original binary mask; right, the augmented, interpolated mask]
Code for the images:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
kp = tf.keras.preprocessing  # assumed: kp is an alias for keras.preprocessing
data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest')
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, seed=1)  # image: array of shape (N, H, W, C)
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.squeeze(image))  # original mask
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(np.squeeze(image_generator.next()[0]))  # first augmented sample
plt.axis('off')
plt.savefig('vis/keras_example')
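To make the problem concrete: spatial transforms resample the image, and linear interpolation averages neighboring pixels, so a 0/1 mask picks up intermediate values along its boundary, whereas nearest-neighbor sampling only copies existing pixels. The NumPy sketch below shifts a small square mask by half a pixel in both directions with each method. `shift_mask` is a hypothetical helper written for illustration; it is not how Keras implements its transforms, and its border clamping only loosely resembles `fill_mode='nearest'`.

```python
import numpy as np


def shift_mask(mask, dy, dx, nearest):
    """Shift a 2-D mask by (dy, dx) pixels (hypothetical illustration helper).

    Each output pixel samples the input at (row - dy, col - dx), either by
    rounding to the nearest pixel or by bilinear interpolation. Source
    coordinates outside the image are clamped to the border.
    """
    h, w = mask.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    src_r = np.clip(rows - dy, 0, h - 1)
    src_c = np.clip(cols - dx, 0, w - 1)
    if nearest:
        # nearest-neighbor: copy an existing pixel, so only 0s and 1s survive
        return mask[np.rint(src_r).astype(int), np.rint(src_c).astype(int)]
    # bilinear: weighted average of the four surrounding pixels
    r0 = np.floor(src_r).astype(int)
    c0 = np.floor(src_c).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    wr = src_r - r0
    wc = src_c - c0
    top = mask[r0, c0] * (1 - wc) + mask[r0, c1] * wc
    bottom = mask[r1, c0] * (1 - wc) + mask[r1, c1] * wc
    return top * (1 - wr) + bottom * wr


mask = np.zeros((8, 8), dtype=np.float32)
mask[2:6, 2:6] = 1.0  # a 4x4 square of foreground

bilinear = shift_mask(mask, 0.5, 0.5, nearest=False)
nearest = shift_mask(mask, 0.5, 0.5, nearest=True)
print(np.unique(bilinear))  # fractional values such as 0.25 and 0.5 appear at the edges
print(np.unique(nearest))   # [0. 1.]
```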
Answer 1:
I had the same problem with my own binary image data. There are several ways to approach this issue.
Simple answer: I solved it by manually converting the output of ImageDataGenerator back to binary. If you iterate over the generator manually (using the next() method or a for loop), you can simply use NumPy's where function to map non-binary values to binary ones:
import numpy as np
batch = image_generator.next()
binary_images = np.where(batch > 0, 1, 0)  # or batch > 0.5, or any other threshold
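The threshold passed to `np.where` decides how the interpolated boundary is resolved. With `> 0`, every pixel that interpolation touched at all becomes foreground, so the mask tends to grow along its edges; with `> 0.5`, only pixels that are at least half foreground are kept. The values below are made up to mimic one interpolated edge:

```python
import numpy as np

# Hypothetical row of an interpolated mask, as augmentation might produce at a boundary
batch = np.array([[0.0, 0.1, 0.5, 0.9, 1.0, 0.6, 0.0]], dtype=np.float32)

loose = np.where(batch > 0, 1, 0)     # any non-zero value becomes foreground
strict = np.where(batch > 0.5, 1, 0)  # only pixels that are mostly foreground
print(loose)   # [[0 1 1 1 1 1 0]]
print(strict)  # [[0 0 0 1 1 1 0]]
```

Note that `np.where(..., 1, 0)` returns an integer array; appending `.astype(batch.dtype)`, as the preprocessing_function variant does, keeps the original float dtype.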
Using the preprocessing_function argument in ImageDataGenerator
A better way is to use the preprocessing_function argument of ImageDataGenerator. As the documentation says, you can specify a custom preprocessing function that is executed after the data augmentation, so you can add this function to your data_gen_args as follows:
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest',
                     # runs after augmentation: maps every interpolated value back to 0 or 1
                     preprocessing_function=lambda x: np.where(x > 0, 1, 0).astype(x.dtype))
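A possible refinement, assuming you augment images and masks with two separate generators: pass the binarizing function to the mask generator only, so the images keep their interpolated intensities, and give both `flow` calls the same `seed` and batch size so their random transforms stay in step (the pattern the Keras documentation uses for transforming images and masks together). The sketch defines the function with a name instead of a lambda and checks it on a tiny array; `binarize_mask`, `base_args`, `images`, and `masks` are hypothetical names, and the Keras wiring appears only in comments. The 0.5 threshold assumes masks stored as 0/1; a 0/255 mask would need a threshold around 127.5.

```python
import numpy as np


def binarize_mask(x, threshold=0.5):
    """Hypothetical preprocessing_function for the mask generator only.

    Keeps the input's shape and dtype, mapping values above `threshold`
    to 1 and everything else to 0.
    """
    return (x > threshold).astype(x.dtype)


# Intended wiring (not executed here; base_args stands for the augmentation
# dict from the question, without preprocessing_function):
#   image_datagen = ImageDataGenerator(**base_args)
#   mask_datagen = ImageDataGenerator(**base_args, preprocessing_function=binarize_mask)
#   image_iter = image_datagen.flow(images, seed=1)
#   mask_iter = mask_datagen.flow(masks, seed=1)

sample = np.array([[[0.0], [0.3], [0.7], [1.0]]], dtype=np.float32)  # shape (1, 4, 1)
out = binarize_mask(sample)
print(out.ravel(), out.dtype)  # [0. 0. 1. 1.] float32
```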
Note: in my experience, preprocessing_function is executed before rescale, which can also be passed as an argument of ImageDataGenerator in your data_gen_args. This does not affect your case, but keep it in mind if you ever need that argument.
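To see why that order matters, the following NumPy sketch simulates the two steps on a hypothetical 0/255 mask: first a binarizing function, then multiplication by `rescale`. It does not call Keras; it only assumes the order described in the note. One way around the problem, shown in the second half, is to have the function emit 0/255 so that the later rescale maps the mask onto 0/1.

```python
import numpy as np

# Hypothetical mask pixels on a 0..255 scale after interpolation
raw = np.array([0.0, 3.0, 200.0, 255.0], dtype=np.float32)
rescale = 1.0 / 255

# Step 1: the preprocessing function binarizes to 0/1; step 2: rescale is applied
naive = np.where(raw > 127.5, 1, 0).astype(raw.dtype) * rescale
print(naive)  # foreground pixels end up near 1/255 (about 0.0039) instead of 1

# Emitting 0/255 instead lets the later rescale produce 0/1
fixed = np.where(raw > 127.5, 255, 0).astype(raw.dtype) * rescale
print(fixed)  # foreground pixels are now 1 (up to float rounding)
```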
Create a custom generator
Another solution is to write a custom data generator and modify the output of ImageDataGenerator inside it. Then use this new generator to feed model.fit(). Something like this:
import numpy as np
from tensorflow.keras.utils import Sequence

batch_size = 64
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, batch_size=batch_size, seed=1)

class MyImageDataGenerator(Sequence):
    def __init__(self, generator, data_size, batch_size):
        super().__init__()
        self.generator = generator  # the wrapped ImageDataGenerator iterator
        self.data_size = data_size
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(self.data_size / float(self.batch_size)))

    def __getitem__(self, idx):
        # idx is ignored: the wrapped iterator yields the next augmented batch
        augmented_data = next(self.generator)
        # map the interpolated values back to 0/1
        binary_images = np.where(augmented_data > 0, 1, 0)
        return binary_images

my_image_generator = MyImageDataGenerator(image_generator, data_size=len(image), batch_size=batch_size)
model.fit(my_image_generator, epochs=50)
The generator above is deliberately simple. If needed, you can customize it, for example to also return your labels or multimodal data.
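As one way to extend it, the generator can return image/mask pairs so that `model.fit()` receives both inputs and targets. The sketch assumes two batch iterators (for example two `ImageDataGenerator.flow(...)` calls) created with the same augmentation arguments, batch size, and seed; `PairedMaskSequence` is a hypothetical class name, and the random stand-in iterators at the end exist only so the example runs on its own. If TensorFlow is not installed, it falls back to subclassing `object`, which is enough for the demo but not for `model.fit()`.

```python
import numpy as np

try:  # use the real Keras base class when TensorFlow is available
    from tensorflow.keras.utils import Sequence
except ImportError:  # plain-object fallback so the sketch also runs without TensorFlow
    Sequence = object


class PairedMaskSequence(Sequence):
    """Hypothetical Sequence yielding (augmented_images, binary_masks) batches.

    image_iter and mask_iter are assumed to be batch iterators built with the
    same augmentation arguments, batch size, and seed, so matching batches
    received the same random transforms.
    """

    def __init__(self, image_iter, mask_iter, n_batches, threshold=0.5):
        super().__init__()
        self.image_iter = image_iter
        self.mask_iter = mask_iter
        self.n_batches = n_batches
        self.threshold = threshold

    def __len__(self):
        # number of batches Keras will request per epoch
        return self.n_batches

    def __getitem__(self, idx):
        # idx is ignored: each call pulls the next batch from both iterators
        images = next(self.image_iter)
        masks = next(self.mask_iter)
        binary_masks = (masks > self.threshold).astype(np.float32)
        return images, binary_masks


# Stand-in iterators with random data, in place of real ImageDataGenerator flows
rng = np.random.default_rng(0)
fake_images = iter([rng.random((2, 4, 4, 1), dtype=np.float32) for _ in range(3)])
fake_masks = iter([rng.random((2, 4, 4, 1), dtype=np.float32) for _ in range(3)])

seq = PairedMaskSequence(fake_images, fake_masks, n_batches=3)
x, y = seq[0]
print(x.shape, y.shape)  # (2, 4, 4, 1) (2, 4, 4, 1)
```

With real iterators and TensorFlow installed, the resulting object could then be passed directly, e.g. `model.fit(seq, epochs=50)`.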
Source: https://stackoverflow.com/questions/60521736/keras-imagedatagenerator-interpolates-binary-mask