title: "example xymasking" notebookName: "example_xymasking.ipynb"

Open in Google ColabRun this notebook interactively

XYMasking aug in Albumentations

This transform is a generalization of the TimeMasking and FrequencyMasking from torchaudio.

import albumentations as A
from matplotlib import pyplot as plt
/opt/homebrew/Caskroom/miniconda/base/envs/albumentations_examples/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
import numpy as np
import torch
import torchaudio
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[2], line 1
----> 1 import torchaudio
      2 import torch
      4 import numpy as np
ModuleNotFoundError: No module named 'torchaudio'
def visualize(image):
    plt.figure(figsize=(10, 5))
    plt.axis("off")
    plt.imshow(image)
img = np.load("../images/spectrogram.npy")
img.shape
(128, 256, 4)
visualize(img[:, :, 0])
No code provided

png

No code provided

One vertical stripe (time masking) with fixed width, filled with 0

params1 = {
    "num_masks_x": 1,
    "mask_x_length": 20,
    "fill_value": 0,
}
transform1 = A.Compose([A.XYMasking(**params1, p=1)], strict=True, seed=137)
visualize(transform1(image=img[:, :, 0])["image"])
No code provided

png

No code provided

One vertical stripe (time masking) with randomly sampled width, filled with 0

params2 = {
    "num_masks_x": 1,
    "mask_x_length": (0, 20),  # This line changed from fixed  to a range
    "fill_value": 0,
}
transform2 = A.Compose([A.XYMasking(**params2, p=1)], strict=True, seed=137)
visualize(transform2(image=img[:, :, 0])["image"])
No code provided

png

No code provided

Analogous transform in torchaudio¶

spectrogram = torchaudio.transforms.Spectrogram()
masking = torchaudio.transforms.TimeMasking(time_mask_param=20)
masked = masking(torch.from_numpy(img[:, :, 0]))
visualize(masked.numpy())
No code provided

png

No code provided

One horizontal stripe (frequency masking) with randomly sampled width, filled with 0

params3 = {
    "num_masks_y": 1,
    "mask_y_length": (0, 20),
    "fill_value": 0,
}
transform3 = A.Compose([A.XYMasking(**params3, p=1)], strict=True, seed=137)
visualize(transform3(image=img[:, :, 0])["image"])
No code provided

png

No code provided

Analogous transform in torchaudio

spectrogram = torchaudio.transforms.Spectrogram()
masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
masked = masking(torch.from_numpy(img[:, :, 0]))
visualize(masked)
No code provided

png

No code provided

Several vertical and horizontal stripes

params4 = {
    "num_masks_x": (2, 4),
    "num_masks_y": 5,
    "mask_y_length": 8,
    "mask_x_length": (10, 20),
    "fill_value": 0,
}
transform4 = A.Compose([A.XYMasking(**params4, p=1)], strict=True, seed=137)
visualize(transform4(image=img[:, :, 0])["image"])
No code provided

png

No code provided

Application to the image with the number of channels larger than 3, and different fill values for different channels

Transform can work with any number of channels supporing image shapes of

  • Grayscale: (height, width)
  • RGB: (height, width, 3)
  • Multichannel: (heigh, width, num_channels)

For value that is used to fill masking regions you can use:

  • scalar that will be applied to every channel
  • list of numbers equal to the number of channels, so that every channel will have it's own filling value
params5 = {
    "num_masks_x": (2, 4),
    "num_masks_y": 5,
    "mask_y_length": 8,
    "mask_x_length": (20, 30),
    "fill_value": (0, 1, 2, 3),
}
transform5 = A.Compose([A.XYMasking(**params5, p=1)], strict=True, seed=137)
transformed = transform5(image=img)["image"]
fig, axs = plt.subplots(2, 2)
vmin = 0
vmax = 3

axs[0, 0].imshow(transformed[:, :, 0], vmin=vmin, vmax=vmax)
axs[0, 0].set_title("Channel 0")
axs[0, 0].axis("off")  # Hide axes for cleaner visualization

axs[0, 1].imshow(transformed[:, :, 1], vmin=vmin, vmax=vmax)
axs[0, 1].set_title("Channel 1")
axs[0, 1].axis("off")

axs[1, 0].imshow(transformed[:, :, 2], vmin=vmin, vmax=vmax)
axs[1, 0].set_title("Channel 2")
axs[1, 0].axis("off")

axs[1, 1].imshow(transformed[:, :, 3], vmin=vmin, vmax=vmax)
axs[1, 1].set_title("Channel 3")
axs[1, 1].axis("off")

plt.tight_layout()

plt.show()
No code provided

png

No code provided