Question
I am new to PyTorch. I am trying to create a DataLoader for a dataset of images where each image has a corresponding ground truth image with the same name:
root:
--->RGB:
------>img1.png
------>img2.png
------>...
------>imgN.png
--->GT:
------>img1.png
------>img2.png
------>...
------>imgN.png
When I use the path of the root folder (which contains the RGB and GT folders) as input for torchvision.datasets.ImageFolder, it reads all of the images as if they were all intended as input (classified as RGB and GT), and there seems to be no way to pair the RGB-GT images. I would like to pair the RGB-GT images, shuffle them, and divide them into batches of a defined size. How can it be done? Any advice will be appreciated.
Thanks.
Answer 1:
I think a good starting point is to use the VisionDataset class as a base. What we are going to follow here is the DatasetFolder source code, so we are going to create something similar. You can notice that this class depends on two other functions from the datasets.folder module: default_loader and make_dataset.
We are not going to modify default_loader, because it is already fine: it just helps us load images, so we will import it.
But we need a new make_dataset function that prepares the right pairs of images from the root folder. The original make_dataset pairs each image path with its parent folder as the target class (class index), so it gives a list of (path, class_to_idx[target]) pairs, but we need (rgb_path, gt_path) pairs. Here is the code for the new make_dataset:
import os

def make_dataset(root: str) -> list:
    """Reads a directory with data.

    Returns a dataset as a list of tuples of paired image paths: (rgb_path, gt_path)
    """
    dataset = []
    # Our dir names
    rgb_dir = 'RGB'
    gt_dir = 'GT'
    # Get all the filenames from RGB folder
    rgb_fnames = sorted(os.listdir(os.path.join(root, rgb_dir)))
    # Compare file names from GT folder to file names from RGB:
    for gt_fname in sorted(os.listdir(os.path.join(root, gt_dir))):
        if gt_fname in rgb_fnames:
            # if we have a match - create pair of full paths to the corresponding images
            rgb_path = os.path.join(root, rgb_dir, gt_fname)
            gt_path = os.path.join(root, gt_dir, gt_fname)
            item = (rgb_path, gt_path)
            # append to the list dataset
            dataset.append(item)
        # GT files without an RGB counterpart are skipped
    return dataset
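Note that make_dataset above silently skips any file that exists in only one of the two folders. A minimal sketch of a check for such files is shown below; find_unpaired is a hypothetical helper (not part of torchvision or of the original answer), and the demo runs on a temporary directory with empty placeholder files rather than real images.

```python
import os
import tempfile

def find_unpaired(root, rgb_dir="RGB", gt_dir="GT"):
    """Hypothetical helper: return (rgb_only, gt_only), the sorted file
    names that are present in one folder but missing from the other."""
    rgb_names = set(os.listdir(os.path.join(root, rgb_dir)))
    gt_names = set(os.listdir(os.path.join(root, gt_dir)))
    return sorted(rgb_names - gt_names), sorted(gt_names - rgb_names)

# Demo on a throwaway directory tree with empty placeholder files.
layout = (("RGB", ["img1.png", "img2.png"]), ("GT", ["img1.png", "img3.png"]))
with tempfile.TemporaryDirectory() as root:
    for sub, names in layout:
        os.makedirs(os.path.join(root, sub))
        for name in names:
            open(os.path.join(root, sub, name), "w").close()
    rgb_only, gt_only = find_unpaired(root)

print(rgb_only)  # ['img2.png']
print(gt_only)   # ['img3.png']
```

Running a check like this before training makes it obvious when the dataset is smaller than expected because of missing or misnamed files.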
What do we have now? Let's compare our function with the original one:
from torchvision.datasets.folder import make_dataset as make_dataset_original
root = './data'
dataset_original = make_dataset_original(root, {'RGB': 0, 'GT': 1}, extensions='png')
dataset = make_dataset(root)
print('Original make_dataset:')
print(*dataset_original, sep='\n')
print('Our make_dataset:')
print(*dataset, sep='\n')
Original make_dataset:
('./data/GT/img1.png', 1)
('./data/GT/img2.png', 1)
...
('./data/RGB/img1.png', 0)
('./data/RGB/img2.png', 0)
...
Our make_dataset:
('./data/RGB/img1.png', './data/GT/img1.png')
('./data/RGB/img2.png', './data/GT/img2.png')
...
I think it works great. It's time to create our Dataset class. The most important part here is the __getitem__ method, because it loads the images, applies the transformations and returns tensors that can be used by dataloaders. We need to read a pair of images (RGB and GT) and return a tuple of 2 tensor images:
from torchvision.datasets.folder import default_loader
from torchvision.datasets.vision import VisionDataset

class CustomVisionDataset(VisionDataset):
    def __init__(self,
                 root,
                 loader=default_loader,
                 rgb_transform=None,
                 gt_transform=None):
        super().__init__(root,
                         transform=rgb_transform,
                         target_transform=gt_transform)
        # Prepare dataset
        samples = make_dataset(self.root)
        self.loader = loader
        self.samples = samples
        # list of RGB image paths (first element of each pair)
        self.rgb_samples = [s[0] for s in samples]
        # list of GT image paths (second element of each pair)
        self.gt_samples = [s[1] for s in samples]

    def __getitem__(self, index):
        """Returns a data sample from our dataset.
        """
        # getting our paths to images
        rgb_path, gt_path = self.samples[index]
        # import each image using loader (by default it's PIL)
        rgb_sample = self.loader(rgb_path)
        gt_sample = self.loader(gt_path)
        # here go transforms if needed
        # maybe we need different transforms for each type of image
        if self.transform is not None:
            rgb_sample = self.transform(rgb_sample)
        if self.target_transform is not None:
            gt_sample = self.target_transform(gt_sample)
        # now we return the right imported pair of images (tensors)
        return rgb_sample, gt_sample

    def __len__(self):
        return len(self.samples)
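One caveat with the class above: the RGB and GT transforms are applied independently, so a random augmentation (for example a random flip) passed to both would be drawn separately for each image and could break the pixel alignment of the pair. A minimal sketch of one way around this, assuming both images are already tensors of shape (C, H, W): PairedRandomHorizontalFlip is a hypothetical helper, not a torchvision class, and it would have to be called on the pair inside __getitem__ rather than passed as rgb_transform or gt_transform.

```python
import random
import torch

class PairedRandomHorizontalFlip:
    """Hypothetical helper: flips both tensors of a pair, or neither."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb, gt):
        # A single random draw decides the flip for both tensors,
        # so the RGB image and its ground truth stay aligned.
        if random.random() < self.p:
            rgb = torch.flip(rgb, dims=[-1])  # flip along the width axis
            gt = torch.flip(gt, dims=[-1])
        return rgb, gt

# Demo with identical dummy tensors: after the paired flip they still match.
rgb = torch.rand(3, 4, 5)
gt = rgb.clone()
flip = PairedRandomHorizontalFlip(p=1.0)  # p=1.0 forces the flip for the demo
rgb_flipped, gt_flipped = flip(rgb, gt)
print(torch.equal(rgb_flipped, gt_flipped))  # True
```

Deterministic transforms such as ToTensor do not have this problem and can stay as separate rgb_transform and gt_transform arguments.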
Let's test it:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

bs = 4  # batch size
transforms = ToTensor()  # we need this to convert PIL images to Tensor
shuffle = True
dataset = CustomVisionDataset('./data', rgb_transform=transforms, gt_transform=transforms)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=shuffle)

for batch_idx, (rgb, gt) in enumerate(dataloader):
    print(f'batch {batch_idx+1}:')
    # some plots; the last batch may hold fewer than bs images
    for j in range(rgb.shape[0]):
        plt.figure(figsize=(10, 5))
        plt.subplot(221)
        plt.imshow(rgb[j].squeeze().permute(1, 2, 0))
        plt.title(f'RGB img{j+1}')
        plt.subplot(222)
        plt.imshow(gt[j].squeeze().permute(1, 2, 0))
        plt.title(f'GT img{j+1}')
        plt.show()
Out:
batch 1:
...
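Because the number of image pairs need not be a multiple of the batch size, the last batch from the DataLoader can be smaller than bs. A small sketch of this behaviour is shown below; DummyPairs is a hypothetical stand-in dataset that returns random tensors instead of reading image files, so it runs without any data on disk.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DummyPairs(Dataset):
    """Hypothetical stand-in that returns random (rgb, gt) tensor pairs."""

    def __init__(self, n=10):
        self.n = n

    def __getitem__(self, index):
        # Two random 3x8x8 tensors play the role of an RGB image and its GT.
        return torch.rand(3, 8, 8), torch.rand(3, 8, 8)

    def __len__(self):
        return self.n

# 10 samples with batch_size=4 give two full batches and one partial batch.
loader = DataLoader(DummyPairs(10), batch_size=4, shuffle=True)
batch_sizes = [rgb.shape[0] for rgb, gt in loader]
print(batch_sizes)  # [4, 4, 2]
```

If every batch must have exactly the same size, passing drop_last=True to DataLoader discards the final partial batch.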
Here you can find a notebook with code and simple dummy dataset.
Source: https://stackoverflow.com/questions/59467781/pytorch-dataloader-for-image-gt-dataset