Migrating from torchvision to Albumentations 🔗
This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.
Import the required libraries 🔗
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
An example pipeline that uses torchvision 🔗
class TorchvisionDataset(Dataset):
def __init__(self, file_paths, labels, transform=None):
self.file_paths = file_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
label = self.labels[idx]
file_path = self.file_paths[idx]
# Read an image with PIL
image = Image.open(file_path)
if self.transform:
image = self.transform(image)
return image, label
torchvision_transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
],
)
torchvision_dataset = TorchvisionDataset(
file_paths=["./images/image_1.jpg", "./images/image_2.jpg", "./images/image_3.jpg"],
labels=[1, 2, 3],
transform=torchvision_transform,
)
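As a quick sanity check, indexing the dataset runs the whole pipeline. The file paths above are placeholders, so the sketch below assumes you point them at real images; the output is a normalized, channels-first tensor:

image, label = torchvision_dataset[0]  # assumes ./images/image_1.jpg exists
print(type(image))  # <class 'torch.Tensor'>
print(image.shape)  # torch.Size([3, 224, 224]): ToTensor converts HWC PIL images to CHW tensors
print(label)  # 1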
The same pipeline with Albumentations 🔗
class AlbumentationsDataset(Dataset):
"""__init__ and __len__ functions are the same as in TorchvisionDataset"""
def __init__(self, file_paths, labels, transform=None):
self.file_paths = file_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
label = self.labels[idx]
file_path = self.file_paths[idx]
# Read an image with OpenCV
image = cv2.imread(file_path)
# By default OpenCV uses BGR color space for color images,
# so we need to convert the image to RGB color space.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.transform:
augmented = self.transform(image=image)
image = augmented["image"]
return image, label
albumentations_transform = A.Compose(
[
A.Resize(256, 256),
A.RandomCrop(224, 224),
A.HorizontalFlip(),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
ToTensorV2(),
],
)
albumentations_dataset = AlbumentationsDataset(
file_paths=["./images/image_1.jpg", "./images/image_2.jpg", "./images/image_3.jpg"],
labels=[1, 2, 3],
transform=albumentations_transform,
)
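The main API difference is in how the pipeline is invoked: torchvision transforms are called as transform(image), while an Albumentations pipeline takes named arguments and returns a dictionary, which is why __getitem__ above reads augmented["image"]. A quick sketch of fetching a sample (again assuming the placeholder paths point at real images):

image, label = albumentations_dataset[0]
print(type(image))  # <class 'torch.Tensor'>
print(image.shape)  # torch.Size([3, 224, 224])

Note that ToTensorV2 comes after A.Normalize: A.Normalize scales and normalizes the NumPy image, and ToTensorV2 only converts the HWC array into a CHW PyTorch tensor.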
Using Albumentations with PIL 🔗
You can use PIL instead of OpenCV while working with Albumentations, but in that case you need to convert the PIL image to a NumPy array before applying transformations. Then you need to convert the augmented image back from a NumPy array to a PIL image.
class AlbumentationsPilDataset(Dataset):
"""__init__ and __len__ functions are the same as in TorchvisionDataset"""
def __init__(self, file_paths, labels, transform=None):
self.file_paths = file_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
label = self.labels[idx]
file_path = self.file_paths[idx]
image = Image.open(file_path)
if self.transform:
# Convert PIL image to numpy array
image_np = np.array(image)
# Apply transformations
augmented = self.transform(image=image_np)
# Convert numpy array to PIL Image
image = Image.fromarray(augmented["image"])
return image, label
albumentations_pil_transform = A.Compose(
[
A.Resize(256, 256),
A.RandomCrop(224, 224),
A.HorizontalFlip(),
],
)
# Note that this dataset outputs PIL images, not NumPy arrays or PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
file_paths=["./images/image_1.jpg", "./images/image_2.jpg", "./images/image_3.jpg"],
labels=[1, 2, 3],
transform=albumentations_pil_transform,
)
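Indexing this dataset returns a PIL image, so this pipeline is useful when downstream code expects PIL objects. A minimal sketch, under the same assumption that the placeholder paths point at real images:

image, label = albumentations_pil_dataset[0]
print(type(image))  # <class 'PIL.Image.Image'>
print(image.size)   # (224, 224): PIL reports (width, height)
# If you need a tensor afterwards, convert explicitly, e.g. with torchvision:
tensor_image = transforms.ToTensor()(image)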