Loading data from a custom DataLoader in PyTorch only if the data meets a certain condition

Submitted by 南笙酒味 on 2021-01-29 09:32:14

Question


I have a CSV file with a file path in the first column and that file's label in the second column. A third column is a flag recording whether the sample meets a specific condition ('y' or 'n'). It looks something like this:

+------------+---------+-----+
| Filepath 1 | Label 1 | 'n' |
+------------+---------+-----+
| Filepath 2 | Label 2 | 'n' |
+------------+---------+-----+
| Filepath 3 | Label 3 | 'n' |
+------------+---------+-----+
| Filepath 4 | Label 4 | 'y' |
+------------+---------+-----+

I want __getitem__ to load a sample only when the third (attribute) column is 'y'. However, I get the following error:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>

My code is as follows:

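```python
import cv2
import pandas as pd
from torch.utils.data import Dataset


class InterDataset(Dataset):
    def __init__(self, csv_file, mode, root_dir=None, transform=None, run=None):
        # Column 0: image path, column 1: label, column 2: 'y'/'n' flag
        self.annotations = pd.read_csv(csv_file, header=None)
        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode
        self.run = run

    def __len__(self):
        # Counts every row, including the ones __getitem__ is meant to skip
        return len(self.annotations)

    def __getitem__(self, index):
        if self.mode == 'train':
            # Only rows flagged 'y' should be loaded
            if self.annotations.iloc[index, 2] == 'y':
                img_path = self.annotations.iloc[index, 0]
                image = cv2.imread(img_path, 1)  # 1 = load as 3-channel color
                y_label = self.annotations.iloc[index, 1]

                if self.transform:
                    image = self.transform(image)
                if (index + 1) % 300 == 0:
                    print('Loop {0} done'.format(index))
                return [image, y_label]
        # Any other row (or mode) falls through to here, so Python returns
        # None implicitly; default_collate cannot batch None, hence the error.
```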
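To see where the NoneType in the error comes from, here is a stripped-down sketch that mirrors the control flow of __getitem__ above and passes the resulting samples to PyTorch's default_collate. The get_item helper and the toy flags list are hypothetical stand-ins for the dataset and the CSV's third column:

```python
from torch.utils.data.dataloader import default_collate


def get_item(flags, index):
    # Same shape as __getitem__ above: the only return sits inside the if,
    # so for a non-matching row Python falls off the end and returns None.
    if flags[index] == "y":
        return [index, "label"]


flags = ["n", "n", "n", "y"]  # toy stand-in for the CSV's third column
batch = [get_item(flags, i) for i in range(len(flags))]
print(batch)  # [None, None, None, [3, 'label']]

try:
    default_collate(batch)  # what the DataLoader does with each batch by default
except TypeError as err:
    print("TypeError:", err)  # the NoneType error from the question
```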


Answer 1:


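You get that error because __getitem__ has to return a sample for every index. For rows that don't meet the condition, your method falls through and implicitly returns None, which the DataLoader's default collate function cannot batch. Here are three solutions: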

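  1. There is a library called nonechucks that lets you create dataloaders that can skip samples.
  2. Usually, you could preprocess/clean your data and remove the unwanted samples beforehand.
  3. You could return some indicator that the sample is unwanted, for example
if self.annotations.iloc[index, 2] == 'y':
    return image, y_label
else:
    return -1

Then, in your training loop, check whether the "data" is -1 and skip that iteration. Note that the default collate function cannot put -1 and real samples into the same batch. With batch_size > 1, you would also need a custom collate_fn that filters out the sentinel. I hope this was helpful :)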
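Option 2 can also be done inside the dataset itself. The idea is to filter the DataFrame once in __init__, so that __len__ and __getitem__ only ever see valid rows. Below is a minimal sketch that assumes the three-column, header-less CSV from the question. The class name FilteredDataset and its keep_flag and loader parameters are hypothetical. To read images as the question does, you would pass something like lambda p: cv2.imread(p, 1) as loader.

```python
import pandas as pd
from torch.utils.data import Dataset


class FilteredDataset(Dataset):
    """Hypothetical dataset that keeps only rows whose flag column equals keep_flag."""

    def __init__(self, csv_file, keep_flag="y", loader=None, transform=None):
        df = pd.read_csv(csv_file, header=None)
        # Filter once, up front: unwanted rows never reach __getitem__.
        self.annotations = df[df.iloc[:, 2] == keep_flag]
        self.loader = loader  # hypothetical, e.g. lambda p: cv2.imread(p, 1)
        self.transform = transform

    def __len__(self):
        # Counts only the rows that survived the filter.
        return len(self.annotations)

    def __getitem__(self, index):
        # iloc is positional, so indices 0..len-1 map onto the kept rows.
        path = self.annotations.iloc[index, 0]
        label = self.annotations.iloc[index, 1]
        # Without a loader, return the path itself (handy for testing).
        sample = self.loader(path) if self.loader else path
        if self.transform:
            sample = self.transform(sample)
        return sample, label
```

Because the unwanted rows are gone before batching, the default collate function works unchanged, and every batch is full-sized.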
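A common variant of option 3 (not taken from the answer above) is to return None as the sentinel and drop it in a custom collate_fn before batching. Below is a minimal sketch. FlaggedDataset, its in-memory toy rows, and skip_none_collate are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class FlaggedDataset(Dataset):
    """Hypothetical toy dataset of (value, label, flag) rows held in memory."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        value, label, flag = self.rows[index]
        if flag != "y":
            return None  # sentinel meaning "skip this sample"
        return torch.tensor(value, dtype=torch.float32), label


def skip_none_collate(batch):
    # Drop sentinel samples, then batch the rest with the usual collate.
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None  # every sample in this batch was unwanted
    return default_collate(batch)


rows = [(1.0, 0, "n"), (2.0, 1, "y"), (3.0, 2, "n"), (4.0, 3, "y")]
loader = DataLoader(FlaggedDataset(rows), batch_size=2, collate_fn=skip_none_collate)

for batch in loader:
    if batch is None:
        continue  # nothing usable in this batch
    data, target = batch
    print(data, target)
```

The trade-off is that batches can come out smaller than batch_size. A batch in which every sample was unwanted arrives as None, so the training loop still needs the "if batch is None: continue" check.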



Source: https://stackoverflow.com/questions/64603234/loading-data-from-custom-data-loader-in-pytorch-only-if-the-data-specifies-a-cer
