The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loading of images?
Here's what I did for PyTorch 0.4.1 (it should still work in 1.3):
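import torch
import torchvision

def load_dataset():
    # ImageFolder expects one subfolder per class,
    # e.g. data/train/cat/*.png and data/train/dog/*.png
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    # The default collate function stacks the image tensors, so all images
    # must have the same size (add a Resize transform via Compose if not).
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    # data: [batch, channels, height, width]; target: class indices
    pass  # train network here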