Migrating from torchvision to Albumentations¶
This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.
Import the required libraries¶
In [4]:
Copied!
from PIL import Image
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image import cv2 import numpy as np from torch.utils.data import Dataset from torchvision import transforms import albumentations as A from albumentations.pytorch import ToTensorV2
An example pipeline that uses torchvision¶
In [2]:
Copied!
class TorchvisionDataset(Dataset):
    """Image dataset that reads files with PIL and applies a torchvision-style transform.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels; ``labels[i]`` is returned together with
            the image loaded from ``file_paths[i]``.
        transform: Optional callable invoked as ``transform(image)`` on the
            loaded PIL image. When ``None`` the PIL image is returned as is.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL. The context manager closes the underlying
        # file as soon as the pixels are decoded, and ``convert("RGB")``
        # guarantees three channels so grayscale / RGBA / palette files do
        # not break a 3-channel ``Normalize`` later in the pipeline.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None explicitly: a transform object is free to
        # define ``__len__``/``__bool__`` and could otherwise be skipped.
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# torchvision pipeline: resize to 256x256, take a random 224x224 crop,
# flip horizontally half of the time, convert to a tensor and normalize
# each channel with the given mean / std.
torchvision_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Three sample images with their labels, augmented by the pipeline above.
torchvision_dataset = TorchvisionDataset(
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)
class TorchvisionDataset(Dataset): def __init__(self, file_paths, labels, transform=None): self.file_paths = file_paths self.labels = labels self.transform = transform def __len__(self): return len(self.file_paths) def __getitem__(self, idx): label = self.labels[idx] file_path = self.file_paths[idx] # Read an image with PIL image = Image.open(file_path) if self.transform: image = self.transform(image) return image, label torchvision_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) ]) torchvision_dataset = TorchvisionDataset( file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'], labels=[1, 2, 3], transform=torchvision_transform, )
The same pipeline with Albumentations¶
In [3]:
Copied!
class AlbumentationsDataset(Dataset):
    """Image dataset that reads files with OpenCV and applies an Albumentations transform.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels; ``labels[i]`` is returned together with
            the image loaded from ``file_paths[i]``.
        transform: Optional callable invoked as ``transform(image=image)``
            that returns a mapping holding the result under the ``'image'``
            key. When ``None`` the RGB NumPy array is returned as is.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Raises:
            OSError: If OpenCV could not read the image file.
        """
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)
        # cv2.imread does not raise on a missing or undecodable file; it
        # returns None, which would only fail later with an obscure
        # cvtColor error. Fail here with the offending path instead.
        if image is None:
            raise OSError(f"Could not read image file: {file_path}")

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Compare against None explicitly: a transform object may define
        # ``__len__`` and would be treated as falsy when it is empty.
        if self.transform is not None:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label
# Albumentations pipeline: resize to 256x256, take a random 224x224 crop,
# flip horizontally with the default probability, normalize each channel
# with the given mean / std and finally convert the array to a tensor.
albumentations_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.RandomCrop(height=224, width=224),
        A.HorizontalFlip(),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)

# Three sample images with their labels, augmented by the pipeline above.
albumentations_dataset = AlbumentationsDataset(
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)
class AlbumentationsDataset(Dataset): """__init__ and __len__ functions are the same as in TorchvisionDataset""" def __init__(self, file_paths, labels, transform=None): self.file_paths = file_paths self.labels = labels self.transform = transform def __len__(self): return len(self.file_paths) def __getitem__(self, idx): label = self.labels[idx] file_path = self.file_paths[idx] # Read an image with OpenCV image = cv2.imread(file_path) # By default OpenCV uses BGR color space for color images, # so we need to convert the image to RGB color space. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.transform: augmented = self.transform(image=image) image = augmented['image'] return image, label albumentations_transform = A.Compose([ A.Resize(256, 256), A.RandomCrop(224, 224), A.HorizontalFlip(), A.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ToTensorV2() ]) albumentations_dataset = AlbumentationsDataset( file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'], labels=[1, 2, 3], transform=albumentations_transform, )
Using albumentations with PIL¶
You can use PIL instead of OpenCV while working with Albumentations, but in that case, you need to convert a PIL image to a NumPy array before applying transformations. Then you need to convert the augmented image back from a NumPy array to a PIL image.
In [4]:
Copied!
class AlbumentationsPilDataset(Dataset):
    """Image dataset that reads files with PIL and applies an Albumentations transform.

    The transform operates on NumPy arrays, so the PIL image is converted to
    an array before the transform runs and back to a PIL image afterwards.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels; ``labels[i]`` is returned together with
            the image loaded from ``file_paths[i]``.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping holding the result under the ``'image'``
            key. When ``None`` the PIL image is returned as is.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # The context manager closes the underlying file once the pixels are
        # decoded; ``convert("RGB")`` yields a three-channel image regardless
        # of the file's original mode (grayscale, RGBA, palette, ...).
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None explicitly: a transform object may define
        # ``__len__`` and would be treated as falsy when it is empty.
        if self.transform is not None:
            # Convert PIL image to numpy array
            image_np = np.array(image)
            # Apply transformations
            augmented = self.transform(image=image_np)
            # Convert numpy array to PIL Image
            image = Image.fromarray(augmented['image'])
        return image, label
# Geometric-only Albumentations pipeline: resize to 256x256, take a random
# 224x224 crop and flip horizontally with the default probability.
albumentations_pil_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.RandomCrop(height=224, width=224),
        A.HorizontalFlip(),
    ]
)

# Note that this dataset will output PIL images and not numpy arrays nor PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=[
        './images/image_1.jpg',
        './images/image_2.jpg',
        './images/image_3.jpg',
    ],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)
class AlbumentationsPilDataset(Dataset): """__init__ and __len__ functions are the same as in TorchvisionDataset""" def __init__(self, file_paths, labels, transform=None): self.file_paths = file_paths self.labels = labels self.transform = transform def __len__(self): return len(self.file_paths) def __getitem__(self, idx): label = self.labels[idx] file_path = self.file_paths[idx] image = Image.open(file_path) if self.transform: # Convert PIL image to numpy array image_np = np.array(image) # Apply transformations augmented = self.transform(image=image_np) # Convert numpy array to PIL Image image = Image.fromarray(augmented['image']) return image, label albumentations_pil_transform = A.Compose([ A.Resize(256, 256), A.RandomCrop(224, 224), A.HorizontalFlip(), ]) # Note that this dataset will output PIL images and not numpy arrays nor PyTorch tensors albumentations_pil_dataset = AlbumentationsPilDataset( file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'], labels=[1, 2, 3], transform=albumentations_pil_transform, )
Albumentations equivalents for torchvision transforms¶
torchvision transform | Albumentations transform | Albumentations example |
---|---|---|
Compose | Compose | A.Compose([A.Resize(256, 256), A.RandomCrop(224, 224)]) |
CenterCrop | CenterCrop | A.CenterCrop(256, 256) |
ColorJitter | HueSaturationValue | A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5) |
Pad | PadIfNeeded | A.PadIfNeeded(min_height=512, min_width=512) |
RandomAffine | Affine | A.Affine(scale=(0.9, 1.1), translate_percent=(0.0, 0.2), rotate=(-45, 45), shear=(-15, 15), mode=cv2.BORDER_REFLECT_101, p=0.5) |
RandomCrop | RandomCrop | A.RandomCrop(256, 256) |
RandomGrayscale | ToGray | A.ToGray(p=0.5) |
RandomHorizontalFlip | HorizontalFlip | A.HorizontalFlip(p=0.5) |
RandomPerspective | Perspective | A.Perspective(scale=(0.2, 0.4), fit_output=True, p=0.5) |
RandomRotation | Rotate | A.Rotate(limit=45, p=0.5) |
RandomVerticalFlip | VerticalFlip | A.VerticalFlip(p=0.5) |
Resize | Resize | A.Resize(256, 256) |
GaussianBlur | GaussianBlur | A.GaussianBlur(blur_limit=(3, 7), p=0.5) |
RandomInvert | InvertImg | A.InvertImg(p=0.5) |
RandomPosterize | Posterize | A.Posterize(num_bits=4, p=0.5) |
RandomSolarize | Solarize | A.Solarize(threshold=127, p=0.5) |
RandomAdjustSharpness | Sharpen | A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5) |
RandomAutocontrast | RandomBrightnessContrast | A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.5) |
RandomEqualize | Equalize | A.Equalize(p=0.5) |
RandomErasing | CoarseDropout | A.CoarseDropout(min_height=8, max_height=32, min_width=8, max_width=32, p=0.5) |
Normalize | Normalize | A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |