Migrating from torchvision to Albumentations 🔗
This notebook shows how you can use Albumentations instead of torchvision to perform data augmentation.
Import the required libraries 🔗
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
An example pipeline that uses torchvision 🔗
class TorchvisionDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with PIL
        image = Image.open(file_path)
        if self.transform:
            image = self.transform(image)
        return image, label
torchvision_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ],
)
torchvision_dataset = TorchvisionDataset(
    file_paths=["./images/image_1.jpg", "./images/image_2.jpg", "./images/image_3.jpg"],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)
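If the sample images exist at the paths above, a quick check (an illustrative snippet, not part of the original notebook) confirms the pipeline's output: RandomCrop produces a 224×224 image, and ToTensor converts it to a channels-first PyTorch tensor.

image, label = torchvision_dataset[0]
print(type(image))  # <class 'torch.Tensor'>
print(image.shape)  # torch.Size([3, 224, 224])
print(label)  # 1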
The same pipeline with Albumentations 🔗
class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]
        return image, label
albumentations_transform = A.Compose(
    [
        A.Resize(256, 256),
        A.RandomCrop(224, 224),
        A.HorizontalFlip(),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ],
)
albumentations_dataset = AlbumentationsDataset(
    file_paths=["./images/image_1.jpg", "./images/image_2.jpg", "./images/image_3.jpg"],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)
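Assuming the same image files are available, the Albumentations pipeline should yield a tensor with the same shape and layout as the torchvision one, since ToTensorV2 converts the HWC NumPy array to a CHW tensor (again, a sanity-check snippet added for illustration):

image, label = albumentations_dataset[0]
print(type(image))  # <class 'torch.Tensor'>
print(image.shape)  # torch.Size([3, 224, 224]), channels-first like torchvision's ToTensor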
Using Albumentations with PIL 🔗
You can use PIL instead of OpenCV while working with Albumentations, but in that case, you need to convert a PIL image to a NumPy array before applying transformations. Then you need to convert the augmented image back from a NumPy array to a PIL image.
class AlbumentationsPilDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        image = Image.open(file_path)
        if self.transform:
            # Convert PIL image to numpy array
            image_np = np.array(image)

            # Apply transformations
            augmented = self.transform(image=image_np)

            # Convert numpy array to PIL Image
            image = Image.fromarray(augmented["image"])
        return image, label
albumentations_pil_transform = A.Compose(
    [
        A.Resize(256, 256),
        A.RandomCrop(224, 224),
        A.HorizontalFlip(),
    ],
)
# Note that this dataset will output PIL images, not NumPy arrays or PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=["./images/image_1.jpg", "./images/image_2.jpg", "./images/image_3.jpg"],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)
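Fetching an item (assuming the sample images exist) should return a PIL image of the cropped size; note that PIL reports size as (width, height):

image, label = albumentations_pil_dataset[0]
print(type(image))  # <class 'PIL.Image.Image'>
print(image.size)  # (224, 224)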