How to iterate over two dataloaders simultaneously using PyTorch?

星月不相逢 2021-02-06 06:04

I am trying to implement a Siamese network that takes in two images. I load these images and create two separate dataloaders.

In my loop I want to go through both datalo

5 answers
  •  野的像风
    2021-02-06 06:46

    Further to what has already been mentioned, combining cycle() and zip() might cause a memory leak, especially with image datasets. To avoid that, instead of iterating like this:

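    from itertools import cycle

    from torch.utils.data import DataLoader

    # DummyDataset and do_cool_things() are placeholders for your own code.
    dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
    dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
    num_epochs = 10

    for epoch in range(num_epochs):
        # cycle() keeps a copy of every batch it has yielded, which is where the memory goes.
        for i, (data1, data2) in enumerate(zip(cycle(dataloaders1), dataloaders2)):
            do_cool_things()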
    
    

    you could use:

    dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
    dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
    num_epochs = 10
    
    for epoch in range(num_epochs):
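        # Fresh iterator over the shorter loader at the start of each epoch.
        dataloader_iterator = iter(dataloaders1)

        for i, data1 in enumerate(dataloaders2):

            try:
                data2 = next(dataloader_iterator)
            except StopIteration:
                # The shorter loader ran out: restart it instead of caching its batches.
                dataloader_iterator = iter(dataloaders1)
                data2 = next(dataloader_iterator)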
    
            do_cool_things()
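
    The try/except restart above can also be wrapped in a small generator so the training loop stays flat. This is only a sketch: the helper name restart_when_exhausted is made up for illustration, and plain lists stand in for the two DataLoader objects so the snippet runs without PyTorch. A DataLoader can be passed in the same way, since it can be iterated more than once.

```python
def restart_when_exhausted(iterable):
    """Yield items forever, calling iter() again after each full pass.

    Unlike itertools.cycle, nothing is cached: every pass asks the
    iterable for a fresh iterator, so already-used items can be freed.
    """
    while True:
        yielded = False
        for item in iterable:
            yielded = True
            yield item
        if not yielded:
            # Empty iterable: stop instead of spinning forever.
            return


# Hypothetical stand-ins for dataloaders1 (shorter) and dataloaders2 (longer).
short_loader = ["a0", "a1", "a2"]
long_loader = ["b0", "b1", "b2", "b3", "b4", "b5", "b6"]

# zip() stops when long_loader ends; short_loader is restarted as needed.
pairs = list(zip(restart_when_exhausted(short_loader), long_loader))
print(pairs)
```

    With a shuffling DataLoader, each restart would also draw a new order, whereas cycle() would replay the batches it saved from the first pass.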
    

    Bear in mind that if your datasets return labels as well, you should replace data1 with (inputs1, targets1) and data2 with (inputs2, targets2) in this example, as @Sajad Norouzi said.
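
    As a rough illustration of that unpacking, here is the same loop with (inputs, targets) batches. The names loader_a, loader_b and seen are invented for this sketch, and nested lists stand in for the tensors a real DataLoader would return.

```python
# Each batch is an (inputs, targets) pair; lists stand in for tensors (assumption).
loader_a = [([1, 2], [0, 1]), ([3, 4], [1, 0])]                          # shorter
loader_b = [([10, 20], [1, 1]), ([30, 40], [0, 0]), ([50, 60], [1, 0])]  # longer

seen = []
iterator_a = iter(loader_a)

for i, (inputs1, targets1) in enumerate(loader_b):
    try:
        inputs2, targets2 = next(iterator_a)
    except StopIteration:
        # Shorter loader exhausted: restart it and take its first batch.
        iterator_a = iter(loader_a)
        inputs2, targets2 = next(iterator_a)

    # Placeholder for the real training step.
    seen.append((inputs1, targets1, inputs2, targets2))

print(len(seen))
```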

    Kudos to this comment: https://github.com/pytorch/pytorch/issues/1917#issuecomment-433698337
