Migrating from torchvision to Albumentations¶
This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.
Import the required libraries¶
Python
from PIL import Image
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
An example pipeline that uses torchvision¶
Python
class TorchvisionDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL
        image = Image.open(file_path)
        if self.transform:
            image = self.transform(image)
        return image, label


torchvision_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
])

torchvision_dataset = TorchvisionDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)
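The dataset can then be consumed through a regular PyTorch DataLoader. The sketch below is a minimal illustration and assumes the placeholder paths above point to real image files:

Python
from torch.utils.data import DataLoader

loader = DataLoader(torchvision_dataset, batch_size=2, shuffle=True)

for images, labels in loader:
    # After ToTensor and Normalize, images is a float tensor
    # of shape [batch_size, 3, 224, 224]
    print(images.shape, labels)
    break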
The same pipeline with Albumentations¶
Python
class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label


albumentations_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2()
])

albumentations_dataset = AlbumentationsDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)
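You can also verify the pipeline by calling it directly. Unlike torchvision transforms, an Albumentations pipeline is called with named arguments and returns a dictionary. The following sanity check is a minimal sketch that uses a random NumPy array as a stand-in for a real image:

Python
# A synthetic RGB image as a stand-in for a real photo
dummy_image = np.random.randint(low=0, high=256, size=(300, 400, 3), dtype=np.uint8)

augmented = albumentations_transform(image=dummy_image)

# ToTensorV2 converts the HWC NumPy array to a CHW PyTorch tensor
tensor = augmented['image']
print(type(tensor), tensor.shape)  # torch.Tensor of shape [3, 224, 224]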
Using albumentations with PIL¶
You can use PIL instead of OpenCV while working with Albumentations, but in that case, you need to convert a PIL image to a NumPy array before applying transformations. Then you need to convert the augmented image back from a NumPy array to a PIL image.
Python
class AlbumentationsPilDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        image = Image.open(file_path)
        if self.transform:
            # Convert the PIL image to a NumPy array
            image_np = np.array(image)

            # Apply transformations
            augmented = self.transform(image=image_np)

            # Convert the NumPy array back to a PIL image
            image = Image.fromarray(augmented['image'])
        return image, label


albumentations_pil_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(),
])
# Note that this dataset will output PIL images, not NumPy arrays or PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)
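Fetching a sample works the same way as with the other datasets; only the output type differs. A minimal sketch, again assuming the placeholder paths point to real images:

Python
image, label = albumentations_pil_dataset[0]

# The image comes back as a PIL image because of the
# Image.fromarray conversion in __getitem__
print(type(image), image.size, label)  # PIL Image of size (224, 224)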
Albumentations equivalents for torchvision transforms¶
| torchvision transform | Albumentations transform | Albumentations example |
|---|---|---|
| Compose | Compose | `A.Compose([A.Resize(256, 256), A.RandomCrop(224, 224)])` |
| CenterCrop | CenterCrop | `A.CenterCrop(256, 256)` |
| ColorJitter | HueSaturationValue | `A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5)` |
| Pad | PadIfNeeded | `A.PadIfNeeded(min_height=512, min_width=512)` |
| RandomAffine | Affine | `A.Affine(scale=(0.9, 1.1), translate_percent=(0.0, 0.2), rotate=(-45, 45), shear=(-15, 15), mode=cv2.BORDER_REFLECT_101, p=0.5)` |
| RandomCrop | RandomCrop | `A.RandomCrop(256, 256)` |
| RandomGrayscale | ToGray | `A.ToGray(p=0.5)` |
| RandomHorizontalFlip | HorizontalFlip | `A.HorizontalFlip(p=0.5)` |
| RandomPerspective | Perspective | `A.Perspective(scale=(0.2, 0.4), fit_output=True, p=0.5)` |
| RandomRotation | Rotate | `A.Rotate(limit=45, p=0.5)` |
| RandomVerticalFlip | VerticalFlip | `A.VerticalFlip(p=0.5)` |
| Resize | Resize | `A.Resize(256, 256)` |
| GaussianBlur | GaussianBlur | `A.GaussianBlur(blur_limit=(3, 7), p=0.5)` |
| RandomInvert | InvertImg | `A.InvertImg(p=0.5)` |
| RandomPosterize | Posterize | `A.Posterize(num_bits=4, p=0.5)` |
| RandomSolarize | Solarize | `A.Solarize(threshold=127, p=0.5)` |
| RandomAdjustSharpness | Sharpen | `A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5)` |
| RandomAutocontrast | RandomBrightnessContrast | `A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.5)` |
| RandomEqualize | Equalize | `A.Equalize(p=0.5)` |
| RandomErasing | CoarseDropout | `A.CoarseDropout(min_height=8, max_height=32, min_width=8, max_width=32, p=0.5)` |
| Normalize | Normalize | `A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])` |
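To illustrate how the equivalents from the table compose, here is a hedged sketch of a pipeline that mixes several of them; the parameter values are taken from the examples above and are not tuned recommendations:

Python
combined_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.HorizontalFlip(p=0.5),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# This can be passed as the transform argument to the
# AlbumentationsDataset defined earlier in this notebook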