Question
I have a CSV file with the file path in the first column and the label for that file in the second column. A third column flags whether the sample meets a specific condition. It looks something like this:
+------------+---------+-----+
| Filepath 1 | Label 1 | 'n' |
+------------+---------+-----+
| Filepath 2 | Label 2 | 'n' |
+------------+---------+-----+
| Filepath 3 | Label 3 | 'n' |
+------------+---------+-----+
| Filepath 4 | Label 4 | 'y' |
+------------+---------+-----+
I want __getitem__ to load a sample from the custom dataset only when the attribute column == 'y'. However, I get the following error:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>
My code is as follows:
'''
import cv2
import pandas as pd
from torch.utils.data import Dataset


class InterDataset(Dataset):
    def __init__(self, csv_file, mode, root_dir=None, transform=None, run=None):
        # No header row: columns are file path (0), label (1), attribute flag (2).
        self.annotations = pd.read_csv(csv_file, header=None)
        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode
        self.run = run

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        if self.mode == 'train':
            if self.annotations.iloc[index, 2] == 'n':
                img_path = self.annotations.iloc[index, 0]
                image = cv2.imread(img_path, 1)  # 1 = load as a 3-channel colour image
                y_label = self.annotations.iloc[index, 1]
                if self.transform:
                    image = self.transform(image)
                if (index + 1) % 300 == 0:
                    print('Loop {0} done'.format(index))
                return [image, y_label]
        # Any other row falls through and returns None implicitly.
'''
Answer 1:
You get that error because __getitem__ has to return a sample for every index. When the condition is not met, your method falls through and implicitly returns None, which default_collate cannot batch. Here are three solutions:
- There is a library called nonechucks that lets you create dataloaders that skip samples.
- Usually you would preprocess/clean your data beforehand and remove the unwanted samples.
- You could return some indicator that the sample is unwanted, for example
    if self.annotations.iloc[index, 2] == 'y':
        return data, target
    else:
        return -1
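You can then check in your training loop whether the data is -1 and skip that iteration. Note that the default collate function cannot batch -1 together with (data, target) pairs, so this works as written only with batch_size=1 or with a custom collate_fn. I hope this was helpful :)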
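As a rough illustration of the second and third options, here is a minimal sketch. It is written in plain Python (csv and io only) so that it runs on its own; the names FilteredAnnotations, keep_flag and drop_unwanted are made up for this example, and the demo file names are placeholders. In a real project the class would subclass torch.utils.data.Dataset and load the image inside __getitem__.

```python
import csv
import io


class FilteredAnnotations:
    """Map-style dataset that only exposes rows whose flag column matches.

    Hypothetical stand-in: with PyTorch installed this would subclass
    torch.utils.data.Dataset and read the image in __getitem__.
    """

    def __init__(self, csv_source, keep_flag="y"):
        rows = list(csv.reader(csv_source))
        # Filter once, up front, so every index maps to a wanted sample
        # and __len__ reports only the samples that can be returned.
        self.rows = [row for row in rows if row[2].strip() == keep_flag]

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        path, label, _flag = self.rows[index]
        # A real dataset would load and transform the image at `path` here.
        return path, label


def drop_unwanted(batch, sentinel=-1):
    """Remove None and integer sentinel entries from a list of samples."""
    return [
        sample
        for sample in batch
        if sample is not None
        and not (isinstance(sample, int) and sample == sentinel)
    ]


# Demo with an in-memory CSV (assumed layout: path, label, flag).
demo_csv = io.StringIO(
    "img_a.png,0,n\n"
    "img_b.png,1,y\n"
    "img_c.png,0,y\n"
)
dataset = FilteredAnnotations(demo_csv)
print(len(dataset))  # 2
print(dataset[0])    # ('img_b.png', '1')
print(drop_unwanted([("a", 0), -1, None, ("b", 1)]))  # [('a', 0), ('b', 1)]
```

With the pandas-based dataset from the question, the same filtering could probably be done once in __init__, for example by keeping only the rows where column 2 equals 'y' (self.annotations = self.annotations[self.annotations[2] == 'y']). If you prefer the sentinel approach, a helper like drop_unwanted could be applied inside a custom collate_fn before the remaining samples are passed to default_collate; a batch in which every sample is unwanted would still need separate handling.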
Source: https://stackoverflow.com/questions/64603234/loading-data-from-custom-data-loader-in-pytorch-only-if-the-data-specifies-a-cer