Migrating from torchvision to Albumentations 🔗

This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.

Import the required libraries 🔗

import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

/opt/homebrew/Caskroom/miniconda/base/envs/albumentations_examples/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm

An example pipeline that uses torchvision 🔗

class TorchvisionDataset(Dataset):
    """Image-classification dataset that loads images with PIL for torchvision transforms.

    Args:
        file_paths: Sequence of image paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional callable applied to the loaded PIL image
            (for example a ``torchvision.transforms.Compose``).
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is an RGB PIL image, or whatever ``transform`` returns for it.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL. ``Image.open`` is lazy and keeps the file
        # handle open, so open it in a ``with`` block. ``convert("RGB")`` forces
        # the pixel data to be read before the file is closed. It also
        # guarantees three channels: grayscale, palette and RGBA files would
        # otherwise break a three-channel ``Normalize``.
        with Image.open(file_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
 
 
# Resize to 256x256, take a random 224x224 crop, flip horizontally with
# probability 0.5, then convert to a float tensor and normalize each channel
# with the commonly used ImageNet mean/std.
torchvision_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


torchvision_dataset = TorchvisionDataset(
    transform=torchvision_transform,
    labels=[1, 2, 3],
    file_paths=[f"./images/image_{i}.jpg" for i in (1, 2, 3)],
)

The same pipeline with Albumentations 🔗

class AlbumentationsDataset(Dataset):
    """Image-classification dataset that loads images with OpenCV for Albumentations.

    __init__ and __len__ functions are the same as in TorchvisionDataset; only
    image loading and the transform call convention differ.

    Args:
        file_paths: Sequence of image paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional Albumentations pipeline, called as
            ``transform(image=...)`` and returning a dict with an ``"image"`` key.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file at the given path.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV. ``str()`` lets callers pass pathlib.Path
        # objects as well as strings.
        image = cv2.imread(str(file_path))
        # cv2.imread does not raise on a missing or unreadable file; it returns
        # None, which would otherwise fail later inside cvtColor with an
        # unhelpful error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            augmented = self.transform(image=image)
            image = augmented["image"]
        return image, label
 
 
# Albumentations counterpart of the torchvision pipeline. A.Normalize first
# scales pixels by max_pixel_value (255.0 by default) and then applies mean/std.
# ToTensorV2 therefore only has to turn the HWC array into a CHW tensor.
albumentations_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.RandomCrop(height=224, width=224),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


albumentations_dataset = AlbumentationsDataset(
    transform=albumentations_transform,
    labels=[1, 2, 3],
    file_paths=[f"./images/image_{i}.jpg" for i in (1, 2, 3)],
)

Using albumentations with PIL 🔗

You can use PIL instead of OpenCV while working with Albumentations, but in that case you need to convert the PIL image to a NumPy array before applying transformations, and then convert the augmented image back from a NumPy array to a PIL image.

class AlbumentationsPilDataset(Dataset):
    """Dataset that loads images with PIL and augments them with Albumentations.

    __init__ and __len__ functions are the same as in TorchvisionDataset.
    Albumentations operates on NumPy arrays, so ``__getitem__`` converts the PIL
    image to an array, applies ``transform`` and converts the result back to a
    PIL image.

    Args:
        file_paths: Sequence of image paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.
        transform: Optional Albumentations pipeline, called as
            ``transform(image=...)``. Its ``"image"`` output must remain a
            uint8 array; float output (e.g. after ``A.Normalize``) is not
            accepted by ``Image.fromarray`` in RGB mode.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``image`` is an RGB PIL image."""
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # ``Image.open`` is lazy and keeps the file open, so open it in a
        # ``with`` block. ``convert("RGB")`` loads the pixels before the file is
        # closed and guarantees a three-channel array for Albumentations.
        with Image.open(file_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            # Convert PIL image to numpy array
            image_np = np.array(image)
            # Apply transformations
            augmented = self.transform(image=image_np)
            # Convert numpy array to PIL Image
            image = Image.fromarray(augmented["image"])
        return image, label
 
 
# Geometric-only pipeline: without Normalize/ToTensorV2 the augmented image
# stays a uint8 array, which Image.fromarray can turn back into a PIL image.
albumentations_pil_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.RandomCrop(height=224, width=224),
    A.HorizontalFlip(p=0.5),
])


# Note that this dataset will output PIL images and not numpy arrays nor PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    transform=albumentations_pil_transform,
    labels=[1, 2, 3],
    file_paths=[f"./images/image_{i}.jpg" for i in (1, 2, 3)],
)