XYMasking augmentation in Albumentations
This transform is a generalization of the TimeMasking and FrequencyMasking transforms from torchaudio.
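As a rough correspondence (a minimal sketch, not from the original notebook, assuming the spectrogram is laid out as (frequency, time)): TimeMasking(time_mask_param=T) behaves like a single vertical stripe with mask_x_length=(0, T), and FrequencyMasking(freq_mask_param=F) like a single horizontal stripe with mask_y_length=(0, F). A single XYMasking can apply both at once:
import albumentations as A

# Sketch: one time mask and one frequency mask in a single transform.
time_and_freq_masking = A.XYMasking(
    num_masks_x=1,
    mask_x_length=(0, 20),  # vertical stripe, analogous to TimeMasking(time_mask_param=20)
    num_masks_y=1,
    mask_y_length=(0, 20),  # horizontal stripe, analogous to FrequencyMasking(freq_mask_param=20)
    fill_value=0,
    p=1,
)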
import albumentations as A
from matplotlib import pyplot as plt
import numpy as np
import torch
import torchaudio
def visualize(image):
plt.figure(figsize=(10, 5))
plt.axis("off")
plt.imshow(image)
img = np.load("../images/spectrogram.npy")
img.shape
(128, 256, 4)
visualize(img[:, :, 0])

One vertical stripe (time masking) with fixed width, filled with 0
params1 = {
"num_masks_x": 1,
"mask_x_length": 20,
"fill_value": 0,
}
transform1 = A.Compose([A.XYMasking(**params1, p=1)], strict=True, seed=137)
visualize(transform1(image=img[:, :, 0])["image"])

One vertical stripe (time masking) with randomly sampled width, filled with 0
params2 = {
"num_masks_x": 1,
"mask_x_length": (0, 20), # This line changed from fixed to a range
"fill_value": 0,
}
transform2 = A.Compose([A.XYMasking(**params2, p=1)], strict=True, seed=137)
visualize(transform2(image=img[:, :, 0])["image"])

Analogous transform in torchaudio
spectrogram = torchaudio.transforms.Spectrogram()
masking = torchaudio.transforms.TimeMasking(time_mask_param=20)
masked = masking(torch.from_numpy(img[:, :, 0]))
visualize(masked.numpy())

One horizontal stripe (frequency masking) with randomly sampled width, filled with 0
params3 = {
"num_masks_y": 1,
"mask_y_length": (0, 20),
"fill_value": 0,
}
transform3 = A.Compose([A.XYMasking(**params3, p=1)], strict=True, seed=137)
visualize(transform3(image=img[:, :, 0])["image"])

Analogous transform in torchaudio
spectrogram = torchaudio.transforms.Spectrogram()
masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
masked = masking(torch.from_numpy(img[:, :, 0]))
visualize(masked.numpy())

Several vertical and horizontal stripes
params4 = {
"num_masks_x": (2, 4),
"num_masks_y": 5,
"mask_y_length": 8,
"mask_x_length": (10, 20),
"fill_value": 0,
}
transform4 = A.Compose([A.XYMasking(**params4, p=1)], strict=True, seed=137)
visualize(transform4(image=img[:, :, 0])["image"])

Application to an image with more than 3 channels, using different fill values for different channels
The transform works with any number of channels, supporting image shapes of:
- Grayscale: (height, width)
- RGB: (height, width, 3)
- Multichannel: (height, width, num_channels)
For the value used to fill the masked regions you can use:
- a scalar that is applied to every channel
- a list of numbers equal to the number of channels, so that every channel gets its own fill value
params5 = {
"num_masks_x": (2, 4),
"num_masks_y": 5,
"mask_y_length": 8,
"mask_x_length": (20, 30),
"fill_value": (0, 1, 2, 3),
}
transform5 = A.Compose([A.XYMasking(**params5, p=1)], strict=True, seed=137)
transformed = transform5(image=img)["image"]
fig, axs = plt.subplots(2, 2)
vmin = 0
vmax = 3
axs[0, 0].imshow(transformed[:, :, 0], vmin=vmin, vmax=vmax)
axs[0, 0].set_title("Channel 0")
axs[0, 0].axis("off") # Hide axes for cleaner visualization
axs[0, 1].imshow(transformed[:, :, 1], vmin=vmin, vmax=vmax)
axs[0, 1].set_title("Channel 1")
axs[0, 1].axis("off")
axs[1, 0].imshow(transformed[:, :, 2], vmin=vmin, vmax=vmax)
axs[1, 0].set_title("Channel 2")
axs[1, 0].axis("off")
axs[1, 1].imshow(transformed[:, :, 3], vmin=vmin, vmax=vmax)
axs[1, 1].set_title("Channel 3")
axs[1, 1].axis("off")
plt.tight_layout()
plt.show()
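As a quick sanity check (a sketch, not part of the original notebook), every pixel that the mask changed in channel i should now equal the fill value chosen for that channel, here 0, 1, 2 and 3:
# Sketch: verify per-channel fill values (uses `img` and `transformed` from above).
for channel, fill in enumerate((0, 1, 2, 3)):
    changed = transformed[:, :, channel] != img[:, :, channel]
    if changed.any():
        assert np.all(transformed[:, :, channel][changed] == fill)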
