Skip to content

Migrating from torchvision to Albumentations

This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.

Import the required libraries

Python
from PIL import Image
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

An example pipeline that uses torchvision

Python
class TorchvisionDataset(Dataset):
    """Map-style dataset that loads images with PIL and applies a torchvision transform.

    Args:
        file_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional callable applied to the loaded PIL image.

    ``__getitem__`` returns ``(image, label)``. ``image`` is the transform output,
    or the RGB PIL image itself when no transform is given.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL. Image.open is lazy and keeps the file open until
        # the pixels are loaded, so use a context manager: convert() loads the data
        # into a new image, and the file is then closed. Converting to RGB also
        # guarantees three channels (grayscale/palette/RGBA files would otherwise
        # break a 3-channel Normalize).
        with Image.open(file_path) as pil_image:
            image = pil_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)
        return image, label


# Classic torchvision recipe: resize to 256x256, take a random 224x224 crop,
# flip horizontally with probability 0.5, convert to a CHW float tensor and
# normalize each channel with the given mean/std.
torchvision_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


torchvision_dataset = TorchvisionDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)

The same pipeline with Albumentations

Python
class AlbumentationsDataset(Dataset):
    """Map-style dataset that loads images with OpenCV and applies an Albumentations transform.

    __init__ and __len__ functions are the same as in TorchvisionDataset.

    Args:
        file_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional Albumentations pipeline, called as ``transform(image=...)``.

    Raises:
        OSError: from ``__getitem__`` when OpenCV cannot read the image file.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV. cv2.imread does not raise on a missing or
        # undecodable file; it returns None, which would otherwise surface later
        # as an unhelpful cv2.error from cvtColor.
        image = cv2.imread(file_path)
        if image is None:
            raise OSError(f"cv2.imread could not read image file: {file_path!r}")

        # By default OpenCV uses BGR channel order for color images, so convert
        # to RGB to match the PIL-based pipeline and the per-channel mean/std.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            # Albumentations takes keyword targets and returns a dict of results.
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label


# Albumentations counterpart of torchvision_transform: same resize, crop, flip
# and normalization steps, with ToTensorV2 producing a CHW torch tensor at the end.
albumentations_transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.RandomCrop(height=224, width=224),
        A.HorizontalFlip(p=0.5),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)


albumentations_dataset = AlbumentationsDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)

Using Albumentations with PIL

You can use PIL instead of OpenCV while working with Albumentations. In that case, you need to convert the PIL image to a NumPy array before applying transformations, and then convert the augmented image back from a NumPy array to a PIL image.

Python
class AlbumentationsPilDataset(Dataset):
    """Map-style dataset that loads images with PIL but augments them with Albumentations.

    __init__ and __len__ functions are the same as in TorchvisionDataset.

    Albumentations operates on NumPy arrays, so each image is converted
    PIL -> NumPy before the transform and NumPy -> PIL afterwards. The transform
    must therefore return an array that ``Image.fromarray`` accepts (e.g. uint8
    HxWxC). Pipelines ending in Normalize (float output) or ToTensorV2 (tensor
    output) cannot be used here.

    Args:
        file_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional Albumentations pipeline, called as ``transform(image=...)``.

    ``__getitem__`` returns ``(pil_image, label)``.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Image.open is lazy and keeps the file open; convert() loads the pixels
        # into a new image so the context manager can close the file. Forcing RGB
        # gives a 3-channel array (palette images would otherwise become index
        # arrays and RGBA images 4-channel arrays).
        with Image.open(file_path) as pil_image:
            image = pil_image.convert("RGB")

        if self.transform is not None:
            # Convert PIL image to a NumPy array of shape (H, W, 3) and apply transformations
            augmented = self.transform(image=np.array(image))
            # Convert the augmented NumPy array back to a PIL image
            image = Image.fromarray(augmented['image'])
        return image, label


# Geometry-only pipeline. Normalize and ToTensorV2 are not included: their
# float/tensor output could not be turned back into a PIL image with
# Image.fromarray inside AlbumentationsPilDataset.
albumentations_pil_transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.RandomCrop(height=224, width=224),
        A.HorizontalFlip(p=0.5),
    ]
)


# This dataset yields PIL images rather than NumPy arrays or PyTorch tensors.
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)

Albumentations equivalents for torchvision transforms

| torchvision transform | Albumentations transform | Albumentations example |
|---|---|---|
| `Compose` | `Compose` | `A.Compose([A.Resize(256, 256), A.RandomCrop(224, 224)])` |
| `CenterCrop` | `CenterCrop` | `A.CenterCrop(256, 256)` |
| `ColorJitter` | `HueSaturationValue` | `A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5)` |
| `Pad` | `PadIfNeeded` | `A.PadIfNeeded(min_height=512, min_width=512)` |
| `RandomAffine` | `Affine` | `A.Affine(scale=(0.9, 1.1), translate_percent=(0.0, 0.2), rotate=(-45, 45), shear=(-15, 15), mode=cv2.BORDER_REFLECT_101, p=0.5)` |
| `RandomCrop` | `RandomCrop` | `A.RandomCrop(256, 256)` |
| `RandomGrayscale` | `ToGray` | `A.ToGray(p=0.5)` |
| `RandomHorizontalFlip` | `HorizontalFlip` | `A.HorizontalFlip(p=0.5)` |
| `RandomPerspective` | `Perspective` | `A.Perspective(scale=(0.2, 0.4), fit_output=True, p=0.5)` |
| `RandomRotation` | `Rotate` | `A.Rotate(limit=45, p=0.5)` |
| `RandomVerticalFlip` | `VerticalFlip` | `A.VerticalFlip(p=0.5)` |
| `Resize` | `Resize` | `A.Resize(256, 256)` |
| `GaussianBlur` | `GaussianBlur` | `A.GaussianBlur(blur_limit=(3, 7), p=0.5)` |
| `RandomInvert` | `InvertImg` | `A.InvertImg(p=0.5)` |
| `RandomPosterize` | `Posterize` | `A.Posterize(num_bits=4, p=0.5)` |
| `RandomSolarize` | `Solarize` | `A.Solarize(threshold=127, p=0.5)` |
| `RandomAdjustSharpness` | `Sharpen` | `A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5)` |
| `RandomAutocontrast` | `RandomBrightnessContrast` | `A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.5)` |
| `RandomEqualize` | `Equalize` | `A.Equalize(p=0.5)` |
| `RandomErasing` | `CoarseDropout` | `A.CoarseDropout(min_height=8, max_height=32, min_width=8, max_width=32, p=0.5)` |
| `Normalize` | `Normalize` | `A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])` |