I am trying to implement a Siamese network that takes in two images. I load these images and create two separate dataloaders.
In my loop I want to go through both datalo
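Further to what has already been mentioned, combining cycle() (from itertools) with zip() might create a memory leak, especially with image datasets: cycle() keeps a copy of every item it has yielded so that it can replay them, so every batch of the cycled loader stays in memory. To avoid that, instead of iterating like this: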
    from itertools import cycle
    from torch.utils.data import DataLoader

    # DummyDataset is assumed to be defined elsewhere
    dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
    dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
    num_epochs = 10

    for epoch in range(num_epochs):
        for i, (data1, data2) in enumerate(zip(cycle(dataloaders1), dataloaders2)):
            do_cool_things()
you could use:
    dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
    dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
    num_epochs = 10

    for epoch in range(num_epochs):
        dataloader_iterator = iter(dataloaders1)
        for i, data2 in enumerate(dataloaders2):
            try:
                data1 = next(dataloader_iterator)
            except StopIteration:
                # dataloaders1 is exhausted: restart it instead of caching its batches
                dataloader_iterator = iter(dataloaders1)
                data1 = next(dataloader_iterator)
            do_cool_things()
Bear in mind that if you use labels as well, you should replace data1 with (inputs1, targets1) and data2 with (inputs2, targets2) in this example, as @Sajad Norouzi said.
KUDOS to this one: https://github.com/pytorch/pytorch/issues/1917#issuecomment-433698337
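
If you need this pattern in more than one place, the restart-on-StopIteration loop can be wrapped in a small generator. This is only a sketch: pair_with_restart is a made-up helper name, and plain lists stand in for the two DataLoaders so that it runs without PyTorch. It assumes the shorter source is non-empty and can be iterated more than once (a DataLoader can; a one-shot iterator cannot).

```python
from typing import Iterable, Iterator, Tuple, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def pair_with_restart(short: Iterable[A], long: Iterable[B]) -> Iterator[Tuple[A, B]]:
    """Yield (a, b) for every b in `long`, restarting `short` when it runs out.

    Hypothetical helper. Unlike itertools.cycle, nothing is cached:
    iter(short) is simply called again on exhaustion.
    """
    short_iter = iter(short)
    for b in long:
        try:
            a = next(short_iter)
        except StopIteration:
            short_iter = iter(short)  # start a new pass over the shorter source
            a = next(short_iter)
        yield a, b


# Stand-ins for two DataLoaders that yield (inputs, targets) batches.
loader1 = [([0, 1], [0, 0]), ([2, 3], [1, 1])]      # 2 batches
loader2 = [([10 * k], [k % 2]) for k in range(5)]   # 5 batches

pairs = list(pair_with_restart(loader1, loader2))
for (inputs1, targets1), (inputs2, targets2) in pairs:
    pass  # the training step would go here
```

With real DataLoaders you would pass dataloaders1 and dataloaders2 directly and loop over the generator instead of building a list. Since iter() is called again on each restart, a loader created with shuffle=True should draw a new order for each pass, whereas cycle() replays the batches it stored.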