Question
I am trying to create a custom data generator and don't know how to integrate the yield statement with an infinite loop inside the __getitem__ method.
EDIT: After the answer I realized that the code I am using is a Sequence, which doesn't need a yield statement.
Currently I am returning multiple images with a return statement:
import cv2
import numpy as np
import tensorflow

class DataGenerator(tensorflow.keras.utils.Sequence):
    def __init__(self, files, labels, batch_size=32, shuffle=True, random_state=42):
        'Initialization'
        self.files = files
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.random_state = random_state
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches per epoch; a trailing partial batch is dropped
        return int(np.floor(len(self.files) / self.batch_size))

    def __getitem__(self, index):
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        files_batch = [self.files[k] for k in indexes]
        y = [self.labels[k] for k in indexes]
        # Generate data
        x = self.__data_generation(files_batch)
        return x, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.files))
        if self.shuffle:
            # Note: reseeding with the same value gives the same order every epoch
            np.random.seed(self.random_state)
            np.random.shuffle(self.indexes)

    def __data_generation(self, files):
        imgs = []
        for img_file in files:
            # -1 is cv2.IMREAD_UNCHANGED: load the image as stored
            img = cv2.imread(img_file, -1)
            ###############
            # Augment image
            ###############
            imgs.append(img)
        return imgs
In this article I saw that yield is used in an infinite loop. I don't quite understand that syntax. How is the loop escaped?
Answer 1:
You are using the Sequence API, which works a bit differently from plain generators. In a generator function, you would use the yield keyword to perform iteration inside a while True: loop, so each time Keras calls the generator, it gets a batch of data, and the generator automatically wraps around at the end of the data.
But in a Sequence, there is an index parameter to the __getitem__ function, so no iteration or yield is required; Keras performs the iteration for you. It is designed this way so the sequence can run in parallel using multiprocessing, which is not possible with old generator functions.
So you are doing things the right way; no change is needed.
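To make the index-driven access concrete, here is a minimal sketch in plain Python. ToySequence and run_epoch are hypothetical names invented for illustration; they are not part of Keras and only mimic how a consumer might walk a Sequence by calling __len__ and __getitem__.

```python
class ToySequence:
    """Hypothetical stand-in for a Keras Sequence: batches are fetched by index."""

    def __init__(self, samples, batch_size):
        self.samples = samples
        self.batch_size = batch_size

    def __len__(self):
        # Number of full batches per epoch (a trailing partial batch is dropped)
        return len(self.samples) // self.batch_size

    def __getitem__(self, index):
        # Return the batch at position `index`; no loop and no yield needed
        start = index * self.batch_size
        return self.samples[start:start + self.batch_size]


def run_epoch(sequence):
    # The consumer drives the iteration: one __getitem__ call per batch index
    return [sequence[i] for i in range(len(sequence))]


seq = ToySequence(list(range(10)), batch_size=4)
batches = run_epoch(seq)
print(batches)
```

Because each batch depends only on its index, the calls are independent of one another, which is what makes it possible to hand them out to several workers.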
Answer 2:
Example of a generator in Keras:
def datagenerator(images, labels, batchsize, mode="train"):
    while True:
        start = 0
        end = batchsize
        while start < len(images):
            # load your images from numpy arrays or read from directory
            x = images[start:end]
            y = labels[start:end]
            yield x, y
            start += batchsize
            end += batchsize
Keras wants you to have the infinite loop running in the generator.
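As for how the loop is escaped: it never is. A generator only runs when the consumer asks for the next item, so the while True: loop is paused at each yield, and the consumer simply stops asking. The sketch below shows this in plain Python, assuming a hypothetical batch_generator helper and using itertools.islice in place of Keras; in Keras the number of batches drawn per epoch is set by the steps_per_epoch argument.

```python
from itertools import islice


def batch_generator(samples, batch_size):
    # Hypothetical helper: loops forever, wrapping around at the end of the data
    while True:
        for start in range(0, len(samples), batch_size):
            # Execution pauses here until the consumer requests the next batch
            yield samples[start:start + batch_size]


gen = batch_generator(list(range(5)), batch_size=2)
steps_per_epoch = 3

# The consumer decides how many batches make up an epoch
epoch_1 = list(islice(gen, steps_per_epoch))
epoch_2 = list(islice(gen, steps_per_epoch))
print(epoch_1)
print(epoch_2)
```

The second epoch picks up where the first one stopped, which is why the generator has to wrap around by itself instead of ending.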
If you want to learn about Python generators, then the link in the comments is actually a good place to start.
Source: https://stackoverflow.com/questions/56079223/custom-keras-data-generator-with-yield