How do you load MNIST images into a PyTorch DataLoader?

野的像风 2021-01-30 09:27

The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic simple lo…

2 answers
  •  孤街浪徒
    2021-01-30 09:59

    If you're using MNIST, torchvision already provides a built-in dataset class for it (torchvision.datasets.MNIST), which can also download the data for you.
    You could do:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import pandas as pd
    
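    # MNIST images are single-channel (grayscale), so Normalize takes one mean
    # and one std; (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1].
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])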
    
    mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
    mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                          shuffle=True, num_workers=2)
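
    To sanity-check what the loader yields without downloading anything, here is a sketch that swaps MNIST for a hypothetical random stand-in of the same shape (100 fake 1x28x28 images with values in [0, 1), like ToTensor's output) and applies the same (x - 0.5) / 0.5 arithmetic as Normalize((0.5,), (0.5,)). Only torch is needed; the sizes are illustrative assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 100 random 1x28x28 "images" in [0, 1)
# (the range ToTensor produces) with random labels 0-9. Nothing is downloaded.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))

# Same arithmetic as transforms.Normalize((0.5,), (0.5,)): (x - mean) / std,
# which maps [0, 1) to [-1, 1).
normalized = (images - 0.5) / 0.5

loader = DataLoader(TensorDataset(normalized, labels), batch_size=16, shuffle=True)

batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)  # torch.Size([16, 1, 28, 28])
print(batch_labels.shape)  # torch.Size([16])
print(len(loader))         # 7: six batches of 16 and a final batch of 4
```

    With the real mnistTrainSet, next(iter(mnistTrainLoader)) should give the same batch shapes, since each MNIST image becomes a 1x28x28 tensor after ToTensor.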
    

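    If you want to generalize to a directory of images whose file names and labels are listed in a text file (MNIST-M here), you can subclass torch.utils.data.Dataset. Besides the imports above, this needs os and PIL's Image: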

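    import os

    import pandas as pd
    import torch
    import torchvision.transforms as transforms
    from PIL import Image

    # MNIST-M images are RGB, so here Normalize takes three means and three stds.
    transformMnistm = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


    class mnistmTrainingDataset(torch.utils.data.Dataset):

        def __init__(self, text_file, root_dir, transform=transformMnistm):
            """
            Args:
                text_file (string): path to a text file with one
                    "<image file name> <label>" pair per line
                root_dir (string): directory with all train images
                transform (callable): transform applied to each PIL image
            """
            # header=None: the file has no header row, so the first line is a sample too
            frame = pd.read_csv(text_file, sep=" ", header=None, names=["name", "label"])
            self.name_frame = frame["name"]
            self.label_frame = frame["label"]
            self.root_dir = root_dir
            self.transform = transform

        def __len__(self):
            return len(self.name_frame)

        def __getitem__(self, idx):
            img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx])
            # convert("RGB") guarantees three channels, matching the Normalize above
            image = Image.open(img_name).convert("RGB")
            image = self.transform(image)
            label = int(self.label_frame.iloc[idx])
            sample = {'image': image, 'labels': label}

            return sample


    mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                           root_dir='Downloads/mnist_m/mnist_m_train')

    mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16,
                                                    shuffle=True, num_workers=2)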
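
    The dataset above assumes each line of the label file is "<image file name> <label>" separated by a single space, with no header row. A stdlib-only sketch of that parsing makes explicit what the Dataset indexes into; the helper name read_label_file and the file contents are made up for illustration:

```python
import io


def read_label_file(lines):
    """Hypothetical helper: parse '<file name> <label>' lines into (name, label) pairs.

    Every non-blank line is a sample; there is no header row to skip.
    """
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines, e.g. after a trailing newline
        name, label = line.split()
        samples.append((name, int(label)))
    return samples


# Made-up contents; with a real file you would pass open(path) instead.
fake_label_file = io.StringIO("00000000.png 5\n00000001.png 0\n00000002.png 4\n")
samples = read_label_file(fake_label_file)
print(len(samples))  # 3
print(samples[0])    # ('00000000.png', 5)
```

    Treating the first line as a header (pandas' default) would silently drop one sample, which is why the class passes header=None.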
    

    You can then iterate over it like:

    for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
        print("training sample for mnist-m")
        print(i_batch,sample_batched['image'],sample_batched['labels'])
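
    Because __getitem__ returns a dict, each batch is also a dict: the DataLoader's default collate function batches every key separately, stacking the image tensors and turning the integer labels into a 1-D tensor. A torch-only sketch with hypothetical in-memory samples in place of real MNIST-M files:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical samples shaped like the ones __getitem__ returns above:
# a dict holding a 3x28x28 image tensor and an integer label.
samples = [{'image': torch.rand(3, 28, 28), 'labels': i % 10} for i in range(10)]

# A plain list already supports len() and indexing, so DataLoader accepts it.
loader = DataLoader(samples, batch_size=4)

# 3 batches: sizes 4, 4 and 2 (no shuffling, so labels stay in order).
for i_batch, sample_batched in enumerate(loader):
    # Images are stacked into [batch, 3, 28, 28]; labels become a 1-D tensor.
    print(i_batch, sample_batched['image'].shape, sample_batched['labels'])
```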
    

    There are several ways to generalize image-dataset loading in PyTorch; the one I know of is subclassing torch.utils.data.Dataset and implementing __len__ and __getitem__, as above.
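
    Stripped of the file handling, the pattern is just a class with __len__ and __getitem__ plus an optional transform. A minimal sketch over in-memory data; the class name ListDataset and the fake data are hypothetical:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ListDataset(Dataset):
    """Hypothetical minimal map-style dataset over in-memory (image, label) pairs.

    How samples are stored or loaded (files, a label file, memory) is up to
    the subclass; DataLoader only calls len() and indexing.
    """

    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        image, label = self.pairs[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# Fake 1x28x28 "images" filled with the value i, labelled i.
pairs = [(torch.full((1, 28, 28), float(i)), i) for i in range(6)]
dataset = ListDataset(pairs, transform=lambda x: x / 5.0)
loader = DataLoader(dataset, batch_size=4)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 1, 28, 28])
```

    Replacing the in-memory pairs with file names and opening each image inside __getitem__ gives essentially the MNIST-M class above.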
