How to save and load parameters of an augmentation pipeline
Reproducibility is essential in deep learning. Data scientists and machine learning engineers need a way to save all the parameters of a deep learning pipeline, such as the model, the optimizer, the input datasets, and the augmentation settings, and to recreate exactly the same pipeline from that saved data. Albumentations has built-in functionality to serialize augmentation parameters and save them to a file. You can then load those parameters to recreate the augmentation pipeline.
Import the required libraries
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as A
Define the visualization function
def visualize(image):
    # Show an RGB image at a fixed size, without axis ticks or labels.
    plt.figure(figsize=(6, 6))
    plt.axis('off')
    plt.imshow(image)
Load an image from disk
image = cv2.imread('images/parrot.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
visualize(image)
Define an augmentation pipeline that we want to serialize
transform = A.Compose([
    A.Perspective(),
    A.RandomCrop(768, 768),
    A.OneOf([
        A.RGBShift(),
        A.HueSaturationValue()
    ]),
])
We can pass an augmentation instance to the print function, and it will print its string representation.
print(transform)
Compose([
  RandomCrop(always_apply=False, p=1.0, height=768, width=768),
  OneOf([
    RGBShift(always_apply=False, p=0.5, r_shift_limit=(-20, 20), g_shift_limit=(-20, 20), b_shift_limit=(-20, 20)),
    HueSaturationValue(always_apply=False, p=0.5, hue_shift_limit=(-20, 20), sat_shift_limit=(-30, 30), val_shift_limit=(-20, 20)),
  ], p=0.5),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})
Next, we fix the random seeds of both Python's random module and NumPy to make the augmentation reproducible for visualization purposes, and then augment an example image.
random.seed(42)
np.random.seed(42)
transformed = transform(image=image)
visualize(transformed['image'])
Serializing an augmentation pipeline to a JSON or YAML file
To save the serialized representation of an augmentation pipeline to a JSON file, use the save function from Albumentations.
A.save(transform, '/tmp/transform.json')
To load a serialized representation from a JSON file, use the load function from Albumentations.
loaded_transform = A.load('/tmp/transform.json')
print(loaded_transform)
Compose([
  RandomCrop(always_apply=False, p=1.0, height=768, width=768),
  OneOf([
    RGBShift(always_apply=False, p=0.5, r_shift_limit=(-20, 20), g_shift_limit=(-20, 20), b_shift_limit=(-20, 20)),
    HueSaturationValue(always_apply=False, p=0.5, hue_shift_limit=(-20, 20), sat_shift_limit=(-30, 30), val_shift_limit=(-20, 20)),
  ], p=0.5),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})
Next, we will use the same random seed as before and apply the loaded augmentation pipeline to the same image.
random.seed(42)
np.random.seed(42)
transformed_from_loaded_transform = loaded_transform(image=image)
visualize(transformed_from_loaded_transform['image'])