MixUp transform in Albumentations
This transform creates a weighted average of the original image and a reference image. It also supports the `global_label` and `mask` targets.
In [ ]:
Copied!
import random
import cv2
from matplotlib import pyplot as plt
from pathlib import Path
import numpy as np
import cv2
import albumentations as A
import random import cv2 from matplotlib import pyplot as plt from pathlib import Path import numpy as np import cv2 import albumentations as A
Define a function to visualize an image
In [ ]:
Copied!
def visualize(image):
    """Display ``image`` in a new 10x5-inch figure with the axes hidden."""
    plt.figure(figsize=(10, 5))
    plt.imshow(image)
    # Hide ticks, labels and frame so only the picture is shown.
    plt.axis('off')
def visualize(image): plt.figure(figsize=(10, 5)) plt.axis('off') plt.imshow(image)
In [ ]:
Copied!
def load_rgb(image_path):
    """Read an image file and return it with RGB channel order.

    OpenCV decodes colour images in BGR order; the result is converted to RGB
    so that matplotlib displays the colours correctly.

    Args:
        image_path: Path to the image file (``str`` or ``os.PathLike``).

    Returns:
        np.ndarray: The decoded image in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode the file.
    """
    # cv2.imread reports failure by returning None rather than raising; without
    # this check the error surfaces later as an opaque cv2.error in cvtColor.
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def load_rgb(image_path): image = cv2.imread(image_path) return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
Load the image from disk
In [ ]:
Copied!
# Demo input image; load_rgb returns it in RGB channel order.
img_path = "../images/woman.jpeg"
img = load_rgb(img_path)
img_path = "../images/woman.jpeg" img = load_rgb(img_path)
In [ ]:
Copied!
# Class index of the original image; it is one-hot encoded to form its global label.
class_id = 0
class_id = 0
Visualize the original image
In [ ]:
Copied!
# Show the unaugmented input for comparison with the MixUp results.
visualize(img)
visualize(img)
MixUp transform
To use the transform, we need to define reference data, which can be any sequence or generator.
We also need to define `read_fn`, which transforms items from `reference_data` into dictionaries with the key `image` and the optional keys `global_label` and `mask`.
In [ ]:
Copied!
# Each reference item records where its image lives and its class index;
# read_fn turns these records into the dicts MixUp consumes.
reference_data = [
    {"image_path": path, "class_id": cid}
    for path, cid in (
        ("../images/cat1.jpeg", 1),
        ("../images/tiger.jpeg", 2),
    )
]
reference_data = [{ "image_path": "../images/cat1.jpeg", "class_id": 1}, {"image_path": "../images/tiger.jpeg", "class_id": 2}]
In [ ]:
Copied!
def int_to_onehot(value, num_classes):
    """Convert a class index (or an array of class indices) to one-hot form.

    Args:
        value (int or array-like of int): Class index or indices, each in
            ``[0, num_classes)``.
        num_classes (int): Total number of classes; length of each one-hot vector.

    Returns:
        np.ndarray: Integer array of shape ``(num_classes,)`` for a scalar
        ``value``, or ``(*np.shape(value), num_classes)`` for an array.

    Raises:
        IndexError: If any index lies outside ``[0, num_classes)``.
    """
    indices = np.asarray(value)
    # Negative indices would otherwise wrap around silently (e.g. -1 selects
    # the last class), yielding a plausible-looking but wrong label.
    if np.any(indices < 0) or np.any(indices >= num_classes):
        raise IndexError(
            f"class index {value!r} out of range for {num_classes} classes"
        )
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    return np.eye(num_classes, dtype=int)[indices]
def int_to_onehot(value, num_classes): """Convert an array of integers to one-hot representation. Args: values (np.ndarray): Array of integers to be converted. num_classes (int): Total number of classes, determines the length of one-hot vectors. Returns: np.ndarray: One-hot encoded representation of `values`. """ # Initialize the one-hot encoded array of shape (num_classes,) one_hot = np.zeros(num_classes, dtype=int) # Set the appropriate index to one one_hot[value] = 1 return one_hot
In [ ]:
Copied!
# Length of every one-hot global label; this demo only uses classes 0, 1 and 2.
NUM_CLASSES = 5
NUM_CLASSES = 5
In [ ]:
Copied!
# Common crop size for the input and reference images, so they can be blended.
target_height, target_width = 2500, 1800
target_height = 2500 target_width = 1800
In [ ]:
Copied!
# Reference items can be preprocessed in any way, including with Albumentations
# transforms. Each reference image is padded (only if it is smaller than the
# target size) and then randomly cropped to the target size. Without the
# padding, RandomCrop raises an error on reference images smaller than the crop.
reference_aug = A.Compose(
    [
        A.PadIfNeeded(min_height=target_height, min_width=target_width, p=1),
        A.RandomCrop(width=target_width, height=target_height, p=1),
    ],
    p=1,
)
# We can process data as we want, including application of augmentations transform. reference_aug = A.Compose([A.RandomCrop(width=target_width, height=target_height, p=1)], p=1)
In [ ]:
Copied!
def read_fn(item):
    """Turn one ``reference_data`` record into the dict passed to ``A.MixUp``.

    Args:
        item (dict): Record with ``"image_path"`` and ``"class_id"`` keys.

    Returns:
        dict: ``{"image": <randomly cropped RGB image>,
        "global_label": <one-hot vector of length NUM_CLASSES>}``.
    """
    cropped = reference_aug(image=load_rgb(item["image_path"]))["image"]
    return {
        "image": cropped,
        "global_label": int_to_onehot(item["class_id"], NUM_CLASSES),
    }
def read_fn(item): image = load_rgb(item["image_path"]) transformed_image = reference_aug(image=image)["image"] global_label = int_to_onehot(item["class_id"], NUM_CLASSES) return { "image": transformed_image, "global_label": global_label }
Show reference images
In [ ]:
Copied!
# Preview reference item 0 as read_fn returns it (random crop, so it varies per call).
visualize(read_fn(reference_data[0])["image"])
visualize(read_fn(reference_data[0])["image"])
In [ ]:
Copied!
# Preview reference item 1 as read_fn returns it (random crop, so it varies per call).
visualize(read_fn(reference_data[1])["image"])
visualize(read_fn(reference_data[1])["image"])
In [ ]:
Copied!
# Crop the input to the same size as the references, blend it with an item
# from reference_data via MixUp, then flip the blended result horizontally.
transform = A.Compose(
    [
        A.RandomCrop(height=target_height, width=target_width, p=1),
        A.MixUp(read_fn=read_fn, reference_data=reference_data, p=1),
        A.HorizontalFlip(p=1),
    ],
    p=1,
)
transform = A.Compose([A.RandomCrop(width=target_width, height=target_height, p=1), A.MixUp(reference_data=reference_data, read_fn=read_fn, p=1), A.HorizontalFlip(p=1)], p=1)
In [ ]:
Copied!
# One-hot label of the original image; MixUp blends it with the reference's label.
original_global_label = int_to_onehot(class_id, NUM_CLASSES)
original_global_label = int_to_onehot(class_id, NUM_CLASSES)
In [ ]:
Copied!
# Run the pipeline once and report the blended label and mixing coefficient.
transformed = transform(image=img, global_label=original_global_label)
mixed_label = transformed["global_label"]
mix_coef = transformed["mix_coef"]
print("Global label = ", mixed_label)
print("Mixing coefficient = ", mix_coef)
visualize(transformed["image"])
transformed = transform(image=img, global_label=original_global_label) print("Global label = ", transformed["global_label"]) print("Mixing coefficient = ", transformed["mix_coef"]) visualize(transformed["image"])
Global label = [0.62160869 0.37839131 0. 0. 0. ] Mixing coefficient = 0.6216086930047383
In [ ]:
Copied!
# Run the pipeline again: the chosen reference and mixing coefficient differ per call.
transformed = transform(image=img, global_label=original_global_label)
mixed_label = transformed["global_label"]
mix_coef = transformed["mix_coef"]
print("Global label = ", mixed_label)
print("Mixing coefficient = ", mix_coef)
visualize(transformed["image"])
transformed = transform(image=img, global_label=original_global_label) print("Global label = ", transformed["global_label"]) print("Mixing coefficient = ", transformed["mix_coef"]) visualize(transformed["image"])
Global label = [0.54018375 0. 0.45981625 0. 0. ] Mixing coefficient = 0.5401837474762494