How do you load MNIST images into a PyTorch DataLoader?

野的像风 2021-01-30 09:27

The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loading of images?

2 Answers
  •  余生分开走
    2021-01-30 09:55

    Here's what I did for PyTorch 0.4.1 (it should still work in 1.3):

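    import torch
    import torchvision
    
    def load_dataset():
        # ImageFolder expects one sub-directory per class, e.g.
        # data/train/0/*.png, data/train/1/*.png, and so on.
        # It opens every image as RGB, so each tensor is 3 x H x W.
        data_path = 'data/train/'
        train_dataset = torchvision.datasets.ImageFolder(
            root=data_path,
            transform=torchvision.transforms.ToTensor()
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=64,
            num_workers=0,
            shuffle=True
        )
        return train_loader
    
    for batch_idx, (data, target) in enumerate(load_dataset()):
        # train network
        pass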
    
