How to use different data augmentation for Subsets in PyTorch

大城市里の小女人 submitted on 2020-01-24 17:04:29


How to use different data augmentation (transforms) for different Subsets in PyTorch?

For instance:

train, test = torch.utils.data.random_split(dataset, [80000, 2000])

train and test will have the same transforms as dataset. How to use custom transforms for these subsets?


My current solution is not very elegant, but works:

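from copy import copy

from torch.utils.data import random_split
from torchvision import transforms

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
# Give the training split its own shallow copy of the parent dataset,
# so changing its .transform does not also change the test split
train_dataset.dataset = copy(full_dataset)

# ImageFolder yields PIL images, so convert to tensors before Normalize
test_dataset.dataset.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # example augmentation; use your own
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])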

Basically, I'm defining a new dataset (which is a copy of the original dataset) for one of the splits, and then I define a custom transform for each split.

Note: train_dataset.dataset.transform works because I'm using an ImageFolder dataset, which applies whatever is stored in its .transform attribute.
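To see why the copy is needed: random_split returns Subset objects that each hold a reference to the same parent dataset, so setting .transform through either subset's .dataset overwrites it for both. The plain-Python sketch below reproduces that behaviour without PyTorch; ToyImageFolder and ToySubset are hypothetical stand-ins for ImageFolder and torch.utils.data.Subset, and the lambdas stand in for real transforms.

```python
from copy import copy


class ToyImageFolder:
    """Hypothetical stand-in for ImageFolder: applies self.transform in __getitem__."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.samples)


class ToySubset:
    """Mimics torch.utils.data.Subset: keeps a reference to the parent dataset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


def make_splits(use_copy):
    full = ToyImageFolder([(1, "a"), (2, "b"), (3, "c"), (4, "d")])
    train, test = ToySubset(full, [0, 1, 2]), ToySubset(full, [3])
    if use_copy:
        # Same trick as the answer above: the train split gets its own shallow copy
        train.dataset = copy(full)
    train.dataset.transform = lambda x: x * 10   # stands in for augmentation
    test.dataset.transform = lambda x: x + 0.5   # stands in for test preprocessing
    return train, test


# Without the copy, the test transform (assigned last) silently applies to train too
shared_train, shared_test = make_splits(use_copy=False)
# With the copy, each split keeps its own transform
copied_train, copied_test = make_splits(use_copy=True)
```

The shallow copy is enough here because only the .transform attribute is reassigned; the underlying sample list is still shared, which is fine as long as nobody mutates it.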

If anybody knows a better solution, please share with us!


I've given up and written my own Subset class (almost identical to PyTorch's). I keep the transform in the Subset, not in the parent dataset.

from torch.utils.data import Dataset


class Subset(Dataset):
    """Subset of a dataset at specified indices, with its own transform.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (callable): Applied to each sample of this subset only
    """
    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        # Fetch the raw sample from the parent, then apply this subset's transform
        im, labels = self.dataset[self.indices[idx]]
        return self.transform(im), labels

    def __len__(self):
        return len(self.indices)

You'll also have to write your own split function.
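A minimal sketch of such a split function, assuming the transform-holding Subset above. So that it runs without PyTorch, the sketch re-declares that class without the Dataset base class and shuffles indices with Python's random module instead of torch.randperm; split_with_transforms and its seed parameter are hypothetical names, not a PyTorch API.

```python
import random


class Subset:
    """Same logic as the Subset above, minus the torch Dataset base class."""

    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        im, labels = self.dataset[self.indices[idx]]
        return self.transform(im), labels

    def __len__(self):
        return len(self.indices)


def split_with_transforms(dataset, lengths, transforms, seed=None):
    """Hypothetical helper: randomly partition dataset into len(lengths) Subsets,
    pairing the i-th subset with transforms[i]."""
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of lengths must equal the length of the dataset")
    if len(lengths) != len(transforms):
        raise ValueError("Need exactly one transform per split")
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)  # seeded shuffle -> reproducible splits
    subsets, offset = [], 0
    for length, transform in zip(lengths, transforms):
        subsets.append(Subset(dataset, indices[offset:offset + length], transform))
        offset += length
    return subsets


# Usage: a list of (x, y) pairs stands in for a real dataset
data = [(i, i % 2) for i in range(10)]
train, test = split_with_transforms(data, [8, 2], [lambda x: -x, lambda x: x], seed=0)
```

With PyTorch installed, the same function works unchanged if Subset inherits from torch.utils.data.Dataset as in the answer, and the resulting subsets can be passed to a DataLoader.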


This is what I use (taken from here):

import torch
from torch.utils.data import Dataset, TensorDataset, random_split
from torchvision import transforms


class DatasetFromSubset(Dataset):
    """Wraps a Subset (e.g. from random_split) and applies its own transform."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

Here's an example:

init_dataset = TensorDataset(
    torch.randn(100, 3, 24, 24),
    torch.randint(0, 10, (100,))
)

lengths = [int(len(init_dataset)*0.8), int(len(init_dataset)*0.2)]
train_subset, test_subset = random_split(init_dataset, lengths)

# Both splits use the same transform here; pass a different one to either split as needed
train_dataset = DatasetFromSubset(
    train_subset, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
test_dataset = DatasetFromSubset(
    test_subset, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
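Because DatasetFromSubset calls its transform inside __getitem__, a random augmentation is re-drawn every time a sample is read, so each epoch sees a different variant of a training image while the test wrapper stays deterministic. The sketch below illustrates this in plain Python so it runs without PyTorch: random_flip is a hypothetical stand-in for a torchvision augmentation, and a list of (x, y) pairs stands in for the Subset returned by random_split.

```python
import random


class DatasetFromSubset:
    """Same logic as the DatasetFromSubset class above, minus the Dataset base class."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


rng = random.Random(0)  # seeded so the sketch is reproducible


def random_flip(x):
    """Hypothetical augmentation: reverse the sample with probability 0.5."""
    return x[::-1] if rng.random() < 0.5 else x


subset = [([1, 2, 3], 0), ([4, 5, 6], 1)]  # stand-in for a random_split Subset
train_dataset = DatasetFromSubset(subset, transform=random_flip)
test_dataset = DatasetFromSubset(subset)  # no transform: samples pass through unchanged

# Reading the same training sample repeatedly draws a fresh flip each time
seen = {tuple(train_dataset[0][0]) for _ in range(50)}
```

The underlying subset is never modified; only the value returned from __getitem__ changes, which is why one base dataset can safely back several wrappers with different transforms.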

