The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loader?
If you're using MNIST, there's already a preset dataset in PyTorch via torchvision.
You could do
import os

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image  # os, pandas and PIL are used by the custom dataset further down

# MNIST images are single-channel, so Normalize takes one mean and one std
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                               shuffle=True, num_workers=2)
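To check that this works, you can iterate over the loader; the built-in MNIST dataset yields (image, label) pairs, so each batch unpacks into an image tensor and a label tensor. A minimal sanity-check sketch, assuming the batch size of 16 above:
# quick sanity check: each batch from the built-in dataset is an (images, labels) pair
for i_batch, (images, labels) in enumerate(mnistTrainLoader):
    print(i_batch, images.shape, labels.shape)  # e.g. torch.Size([16, 1, 28, 28]) and torch.Size([16])
    if i_batch == 0:
        break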
If you want to generalize to a directory of images (same imports as above), you could do
# transform for MNIST-M images, which are RGB (three channels)
transformMnistm = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self, text_file, root_dir, transform=transformMnistm):
        """
        Args:
            text_file (string): path to the text file listing image names and labels
            root_dir (string): directory with all the training images
        """
        # assumes each line of the text file is "image_name label" with no header row
        self.name_frame = pd.read_csv(text_file, sep=" ", usecols=range(1), header=None)
        self.label_frame = pd.read_csv(text_file, sep=" ", usecols=range(1, 2), header=None)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        sample = {'image': image, 'labels': labels}
        return sample
mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir='Downloads/mnist_m/mnist_m_train')
mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16,
                                                shuffle=True, num_workers=2)
You can then iterate over it like:
for i_batch, sample_batched in enumerate(mnistmTrainLoader, 0):
    print("training sample for mnist-m")
    print(i_batch, sample_batched['image'], sample_batched['labels'])
There are a bunch of ways to generalize PyTorch for image dataset loading; the method I know of is subclassing torch.utils.data.Dataset.
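Stripped down, the only pieces a subclass really needs are __init__, __len__, and __getitem__. A minimal generic sketch (GenericImageDataset and its list of (image_path, label) pairs are illustrative, not part of any PyTorch API):
class GenericImageDataset(torch.utils.data.Dataset):

    def __init__(self, samples, transform=None):
        # samples: a list of (image_path, label) pairs, built however you like
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path)
        if self.transform is not None:
            image = self.transform(image)
        return {'image': image, 'labels': label}
Any class written this way can be handed to torch.utils.data.DataLoader exactly as above.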