How to iterate over two dataloaders simultaneously using pytorch?

星月不相逢 2021-02-06 06:04

I am trying to implement a Siamese network that takes in two images. I load these images and create two separate dataloaders.

In my loop I want to go through both datalo

5 answers
  •  南方客 (original poster)
    2021-02-06 06:34

    To complete @ManojAcharya's answer:

    The error you are getting comes from neither zip() nor DataLoader() directly. Python is telling you that it could not find one of the data files you asked for (see the FileNotFoundError in the exception trace), most likely inside your Dataset.

    Below is a working example that uses DataLoader and zip together. Note that if you shuffle your data, each loader is shuffled independently, so it becomes difficult to keep the correspondence between the two datasets. This justifies @ManojAcharya's solution.

    import torch
    from torch.utils.data import DataLoader, Dataset
    
    class DummyDataset(Dataset):
        """
        Dataset of numbers in [a,b] inclusive
        """
    
        def __init__(self, a=0, b=100):
            super(DummyDataset, self).__init__()
            self.a = a
            self.b = b
    
        def __len__(self):
            return self.b - self.a + 1
    
        def __getitem__(self, index):
            # Offset by a so the items really cover [a, b], as the docstring states.
            value = self.a + index
            return value, "label_{}".format(value)
    
    dataloaders1 = DataLoader(DummyDataset(0, 9), batch_size=2, shuffle=True)
    dataloaders2 = DataLoader(DummyDataset(0, 9), batch_size=2, shuffle=True)
    
    for i, data in enumerate(zip(dataloaders1, dataloaders2)):
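        print(data)
    # Example output (the order differs between runs and between the two loaders, because shuffle=True):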
    # ([tensor([ 4,  7]), ('label_4', 'label_7')], [tensor([ 8,  5]), ('label_8', 'label_5')])
    # ([tensor([ 1,  9]), ('label_1', 'label_9')], [tensor([ 6,  9]), ('label_6', 'label_9')])
    # ([tensor([ 6,  5]), ('label_6', 'label_5')], [tensor([ 0,  4]), ('label_0', 'label_4')])
    # ([tensor([ 8,  2]), ('label_8', 'label_2')], [tensor([ 2,  7]), ('label_2', 'label_7')])
    # ([tensor([ 0,  3]), ('label_0', 'label_3')], [tensor([ 3,  1]), ('label_3', 'label_1')])
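
One way to keep the pairs aligned while still shuffling is to wrap both datasets in a single Dataset and use one DataLoader, so that one shuffled index selects both items. The following is only a minimal sketch under assumptions: `RangeDataset` and `PairedDataset` are hypothetical names introduced for illustration, both datasets are assumed to have the same length, and this is not necessarily the exact code of @ManojAcharya's answer.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Hypothetical stand-in dataset: item i is (a + i, "label_<a + i>")."""

    def __init__(self, a=0, b=9):
        self.a = a
        self.b = b

    def __len__(self):
        return self.b - self.a + 1

    def __getitem__(self, index):
        value = self.a + index
        return value, "label_{}".format(value)


class PairedDataset(Dataset):
    """Hypothetical wrapper: item i is the pair (first[i], second[i])."""

    def __init__(self, first, second):
        if len(first) != len(second):
            raise ValueError("both datasets must have the same length")
        self.first = first
        self.second = second

    def __len__(self):
        return len(self.first)

    def __getitem__(self, index):
        return self.first[index], self.second[index]


# The second dataset is offset by 100 so that aligned pairs are easy to check.
paired = PairedDataset(RangeDataset(0, 9), RangeDataset(100, 109))
loader = DataLoader(paired, batch_size=2, shuffle=True)

for (values1, labels1), (values2, labels2) in loader:
    # A single shuffled index picks both items, so every pair stays aligned.
    assert torch.equal(values2 - values1, torch.full_like(values1, 100))
    print(values1, labels1, values2, labels2)
```

For a Siamese network, the two wrapped datasets would be the two image datasets, and each batch then yields both images of a pair together.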
    
