Migrating from torchvision to Albumentations¶
This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.
Import the required libraries¶
In [1]:
from PIL import Image
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
An example pipeline that uses torchvision¶
In [2]:
class TorchvisionDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL
        image = Image.open(file_path)
        if self.transform:
            image = self.transform(image)
        return image, label


torchvision_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

torchvision_dataset = TorchvisionDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)
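As a quick sanity check, you can fetch a single sample from the dataset. This sketch assumes the image files listed above actually exist on disk; with the pipeline defined above, each sample comes back as a normalized PyTorch tensor.

# Minimal sanity check (assumes the listed image files exist on disk)
image, label = torchvision_dataset[0]
print(type(image), image.shape, label)  # <class 'torch.Tensor'> torch.Size([3, 224, 224]) 1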
The same pipeline with Albumentations¶
In [3]:
class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label


albumentations_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
])

albumentations_dataset = AlbumentationsDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)
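Note the key API difference from torchvision: an Albumentations pipeline is called with keyword arguments (image=...) and returns a dictionary, from which you read the augmented result under the 'image' key. The sketch below demonstrates this call convention on a synthetic NumPy image, so no files are required; the random array is just a stand-in for real data.

# Albumentations pipelines take named arguments and return a dict.
# The random uint8 array here is a synthetic stand-in for a real photo.
dummy_image = np.random.randint(0, 256, (300, 300, 3), dtype=np.uint8)
augmented = albumentations_transform(image=dummy_image)
tensor_image = augmented['image']
print(type(tensor_image), tensor_image.shape)  # <class 'torch.Tensor'> torch.Size([3, 224, 224])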
Using Albumentations with PIL¶
You can use PIL instead of OpenCV while working with Albumentations, but in that case you need to convert the PIL image to a NumPy array before applying transformations. Then you need to convert the augmented image back from a NumPy array to a PIL image.
In [4]:
class AlbumentationsPilDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        image = Image.open(file_path)
        if self.transform:
            # Convert the PIL image to a NumPy array
            image_np = np.array(image)

            # Apply transformations
            augmented = self.transform(image=image_np)

            # Convert the NumPy array back to a PIL image
            image = Image.fromarray(augmented['image'])
        return image, label


albumentations_pil_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(),
])

# Note that this dataset will output PIL images, not NumPy arrays or PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)
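You can verify the round trip back to PIL by fetching one sample; this sketch again assumes the listed image files exist on disk.

# Fetch one sample and confirm we got a PIL image back
pil_image, label = albumentations_pil_dataset[0]
print(type(pil_image), pil_image.size)  # <class 'PIL.Image.Image'> (224, 224)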
Albumentations equivalents for torchvision transforms¶
| torchvision transform | Albumentations transform | Albumentations example |
|---|---|---|
| Compose | Compose | `A.Compose([A.Resize(256, 256), A.RandomCrop(224, 224)])` |
| CenterCrop | CenterCrop | `A.CenterCrop(256, 256)` |
| ColorJitter | HueSaturationValue | `A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5)` |
| Pad | PadIfNeeded | `A.PadIfNeeded(min_height=512, min_width=512)` |
| RandomAffine | Affine | `A.Affine(scale=(0.9, 1.1), translate_percent=(0.0, 0.2), rotate=(-45, 45), shear=(-15, 15), mode=cv2.BORDER_REFLECT_101, p=0.5)` |
| RandomCrop | RandomCrop | `A.RandomCrop(256, 256)` |
| RandomGrayscale | ToGray | `A.ToGray(p=0.5)` |
| RandomHorizontalFlip | HorizontalFlip | `A.HorizontalFlip(p=0.5)` |
| RandomPerspective | Perspective | `A.Perspective(scale=(0.2, 0.4), fit_output=True, p=0.5)` |
| RandomRotation | Rotate | `A.Rotate(limit=45, p=0.5)` |
| RandomVerticalFlip | VerticalFlip | `A.VerticalFlip(p=0.5)` |
| Resize | Resize | `A.Resize(256, 256)` |
| GaussianBlur | GaussianBlur | `A.GaussianBlur(blur_limit=(3, 7), p=0.5)` |
| RandomInvert | InvertImg | `A.InvertImg(p=0.5)` |
| RandomPosterize | Posterize | `A.Posterize(num_bits=4, p=0.5)` |
| RandomSolarize | Solarize | `A.Solarize(threshold=127, p=0.5)` |
| RandomAdjustSharpness | Sharpen | `A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5)` |
| RandomAutocontrast | RandomBrightnessContrast | `A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.5)` |
| RandomEqualize | Equalize | `A.Equalize(p=0.5)` |
| RandomErasing | CoarseDropout | `A.CoarseDropout(min_height=8, max_height=32, min_width=8, max_width=32, p=0.5)` |
| Normalize | Normalize | `A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])` |
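Putting the table to work, here is an illustrative sketch that translates a small torchvision pipeline into its Albumentations counterpart using the rows above. The parameter values are examples only; keep in mind that p controls the application probability in Albumentations, so transforms that torchvision always applies (such as RandomRotation, which rotates every sample by a random angle) are closest to p=1.0.

# Illustrative translation built from the table rows above
torchvision_pipeline = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(45),
    transforms.RandomHorizontalFlip(p=0.5),
])

# torchvision's RandomRotation rotates every sample, so p=1.0 is the closest match
albumentations_pipeline = A.Compose([
    A.Resize(256, 256),
    A.Rotate(limit=45, p=1.0),
    A.HorizontalFlip(p=0.5),
])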