Creating Custom Albumentations Transforms


While Albumentations provides a wide array of built-in transformations, you might need to create your own custom logic. This guide walks through the core patterns: choosing a base class, sampling parameters, applying them to one or more targets, validating configuration, and passing custom data.

Custom Transform Flow

Motivation

  • Implement novel augmentation techniques.
  • Create domain-specific transforms.
  • Wrap functions from external image processing libraries.
  • Encapsulate complex or conditional logic.

Quick Reference

import albumentations as A

# Base classes (choose one):
#   A.ImageOnlyTransform  - modifies only image pixels
#   A.DualTransform       - modifies image + mask/bboxes/keypoints
#   A.Transform3D         - for 3D/volumetric data

class MyTransform(A.DualTransform):
    def __init__(self, p=0.5):                             # Store configuration
        super().__init__(p=p)

    def get_params_dependent_on_data(self, params, data):  # Sample once per call
        return {}

    def apply(self, img, **params):                        # Transform image
        return img

    def apply_to_mask(self, mask, **params):               # Optional target route
        return mask

| Use case | Base class | Key methods |
|---|---|---|
| Change pixels only | ImageOnlyTransform | apply |
| Change image + mask/bboxes/keypoints | DualTransform | apply, apply_to_mask, apply_to_bboxes, apply_to_keypoints |
| Handle 3D volumes | Transform3D | apply_to_volume |
| Route project-specific data | Usually DualTransform | targets, apply_to_<key> |

Core rules:

  • p=0.5 by default; pass p=1.0 for transforms that should always run.
  • Sample random values in get_params_dependent_on_data, not in apply.
  • Use self.py_random or self.random_generator, never global random or np.random, so A.Compose(..., seed=...) remains reproducible.

Core Concept: Inheriting from Base Classes

To integrate with Albumentations, custom transforms inherit from base classes like A.ImageOnlyTransform, A.DualTransform, or A.Transform3D. These base classes provide the structure for handling different data targets (image, mask, bboxes, etc.) and probabilities.

⚠️ Important: Default Probability

The BasicTransform class, which all transform classes inherit from, has a default value of p=0.5. This means if you don't explicitly specify the probability parameter in your custom transform's constructor, it will only be applied 50% of the time.

If you want your transform to be always applied, make sure to pass p=1.0 to the parent class constructor:

import albumentations as A

class MyTransform(A.ImageOnlyTransform):
    def __init__(self):
        # This transform will only be applied 50% of the time
        super().__init__()  # p defaults to 0.5

class MyAlwaysAppliedTransform(A.ImageOnlyTransform):
    def __init__(self):
        # This transform will always be applied
        super().__init__(p=1.0)
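The sketch below models that default with a simplified probability gate: a call applies the transform when a uniform draw in [0, 1) falls below p. The helper apply_rate is hypothetical and only mimics this gating rather than reproducing Albumentations' internals, but it shows why a transform left at p=0.5 runs on roughly half of the calls while p=1.0 runs on every one.

```python
import random


def apply_rate(p: float, n_calls: int, seed: int) -> float:
    """Fraction of calls on which a transform with probability p would run.

    Simplified model (hypothetical helper): a call applies the transform
    when a uniform draw in [0, 1) is below p.
    """
    rng = random.Random(seed)  # private, seeded generator
    applied = sum(rng.random() < p for _ in range(n_calls))
    return applied / n_calls


rate_default = apply_rate(0.5, n_calls=10_000, seed=137)
rate_always = apply_rate(1.0, n_calls=10_000, seed=137)
print(f"p=0.5 -> {rate_default:.3f}, p=1.0 -> {rate_always:.3f}")
```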

🟢 Step 1: Minimal Image-Only Transform (Essential)

Let's start with the simplest case: a transform that only modifies the image pixels and involves some randomness. We'll inherit from A.ImageOnlyTransform.

Goal: Create a transform that multiplies the image pixel values by a random factor between 0.5 and 1.5.

import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
import numpy as np

class RandomMultiplier(ImageOnlyTransform):
    """Multiplies the pixel values of an image by a random factor."""

    def __init__(self, factor_range=(0.5, 1.5), p=0.5):
        super().__init__(p=p)
        self.factor_range = factor_range

    def get_params_dependent_on_data(self, params, data):
        factor = self.py_random.uniform(*self.factor_range)
        return {"factor": factor}

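    def apply(self, img, factor, **params):
        # Assumes uint8 input: clipping to [0, 255] keeps the cast back to img.dtype safe.
        return np.clip(img * factor, 0, 255).astype(img.dtype)

# --- Usage ---
random_mult = RandomMultiplier(factor_range=(0.8, 1.2), p=1.0)

pipeline = A.Compose([
    A.Resize(256, 256),
    random_mult,
], seed=137)

image = np.random.default_rng(0).integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
augmented_image = pipeline(image=image)["image"]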

get_params_dependent_on_data runs once per transform call. Values returned from it are passed to apply and any apply_to_* methods through **params. Use the transform's own random generator (self.py_random or self.random_generator) so seeded pipelines stay reproducible.
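To make this flow concrete, here is a deliberately simplified toy model. ToyShift and its __call__ method are hypothetical and are not how Albumentations is implemented internally; they only mirror the contract described above: parameters are sampled once per call with a private seeded generator, then the same dictionary is passed to every target method, so the image and mask stay aligned and a fixed seed reproduces the result.

```python
import random
from typing import Any

import numpy as np


class ToyShift:
    """Hypothetical toy model of the call flow, NOT Albumentations' implementation.

    Parameters are sampled once per call and the same dict is passed to every
    target method, so the image and the mask are shifted by the same amount.
    """

    def __init__(self, max_shift: int, seed: int):
        self.max_shift = max_shift
        self.py_random = random.Random(seed)  # private, seeded generator

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        return {"dx": self.py_random.randint(-self.max_shift, self.max_shift)}

    def apply(self, img: np.ndarray, dx: int, **params: Any) -> np.ndarray:
        return np.roll(img, dx, axis=1)  # horizontal circular shift

    def apply_to_mask(self, mask: np.ndarray, dx: int, **params: Any) -> np.ndarray:
        return np.roll(mask, dx, axis=1)

    def __call__(self, **data: Any) -> dict[str, Any]:
        # 1. Sample once for this call.
        params = self.get_params_dependent_on_data({"shape": data["image"].shape}, data)
        # 2. Route each target to its method, always with the same params.
        routes = {"image": self.apply, "mask": self.apply_to_mask}
        return {key: routes[key](value, **params) for key, value in data.items()}


image = np.arange(12).reshape(3, 4)
mask = image.copy()  # same content, so aligned outputs must be identical

out = ToyShift(max_shift=2, seed=137)(image=image, mask=mask)
again = ToyShift(max_shift=2, seed=137)(image=image, mask=mask)
print(np.array_equal(out["image"], out["mask"]))    # True: image and mask stay aligned
print(np.array_equal(out["image"], again["image"]))  # True: same seed, same result
```

In a real custom transform you only write get_params_dependent_on_data and the apply methods; the base class and A.Compose handle the sampling and routing.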


🟢 Step 2: Data-Dependent Image-Only Transform (Essential)

Next, let's create a transform whose parameters depend on the input image itself. We still use ImageOnlyTransform because only the image pixels change.

Goal: Create a transform that normalizes the image using its per-channel mean and standard deviation.

import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
import numpy as np

class PerImageNormalize(ImageOnlyTransform):
    """Normalizes an image using its own mean and standard deviation."""

    def __init__(self, p=1.0):
        super().__init__(p=p)

    def get_params_dependent_on_data(self, params, data):
        img = data["image"]
        mean = np.mean(img, axis=(0, 1))
        std = np.std(img, axis=(0, 1))
        std = np.where(std < 1e-6, 1e-6, std)
        return {"mean": mean, "std": std}

    def apply(self, img, mean, std, **params):
        img = img.astype(np.float32)
        return ((img - mean) / std).astype(np.float32)

Use data["image"] when parameter values depend on the input. Here, the output is float32 because normalization can produce negative and non-integer values.
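As a sanity check of the arithmetic only (plain NumPy, no Albumentations), the snippet below repeats the PerImageNormalize math on a random uint8 image: each channel ends up with a mean close to 0 and a standard deviation close to 1, and the 1e-6 floor turns a constant channel into zeros instead of NaN. np.maximum is used here as an equivalent of the np.where guard.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

# Same math as PerImageNormalize: per-channel statistics over height and width.
mean = img.mean(axis=(0, 1))                   # shape (3,), one value per channel
std = np.maximum(img.std(axis=(0, 1)), 1e-6)   # same guard as np.where(std < 1e-6, 1e-6, std)
out = ((img.astype(np.float32) - mean) / std).astype(np.float32)

print(out.dtype, out.mean(axis=(0, 1)).round(3), out.std(axis=(0, 1)).round(3))

# A constant channel has std == 0; the guard yields zeros instead of NaN.
flat = np.full((8, 8, 3), 7, dtype=np.uint8)
flat_std = np.maximum(flat.std(axis=(0, 1)), 1e-6)
flat_out = (flat.astype(np.float32) - flat.mean(axis=(0, 1))) / flat_std
print(np.isnan(flat_out).any())  # False
```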


🟡 Step 3: Handling Multiple Images Efficiently (Intermediate)

ImageOnlyTransform can also receive an images array, for example a video clip or a stack of related frames. Override apply_to_images when the default per-image loop would be slow or semantically wrong.

Goal: Efficiently normalize each image in a sequence/batch using its own mean and standard deviation, while still inheriting from ImageOnlyTransform.

import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
import numpy as np

class PerImageNormalizeEfficient(ImageOnlyTransform):
    """
    Normalizes an image or a sequence/batch of images using per-image
    mean and standard deviation. Overrides apply_to_images for efficiency.
    """
    def __init__(self, p=1.0):
        super().__init__(p=p)

    def get_params_dependent_on_data(self, params, data):
        params_out = {}
        if "image" in data:
            img = data["image"]
            params_out["mean"] = np.mean(img, axis=(0, 1))
            params_out["std"] = np.std(img, axis=(0, 1))
        elif "images" in data:
            imgs = data["images"]
            params_out["mean"] = np.mean(imgs, axis=(1, 2))
            params_out["std"] = np.std(imgs, axis=(1, 2))
        else:
            return {}

        params_out["std"] = np.where(params_out["std"] < 1e-6, 1e-6, params_out["std"])
        return params_out

    def apply(self, img, mean, std, **params):
        img = img.astype(np.float32)
        return ((img - mean) / std).astype(np.float32)

    def apply_to_images(self, images, mean, std, **params):
        images = images.astype(np.float32)

        if images.ndim == 4:
            mean = mean[:, np.newaxis, np.newaxis, :]
            std = std[:, np.newaxis, np.newaxis, :]
        elif images.ndim == 3:
            mean = mean[:, np.newaxis, np.newaxis]
            std = std[:, np.newaxis, np.newaxis]

        return ((images - mean) / std).astype(np.float32)

The key difference from Step 2 is that get_params_dependent_on_data returns one mean/std pair per frame, and apply_to_images reshapes those arrays for vectorized broadcasting.
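The reshape in apply_to_images is what makes the vectorized path correct. This NumPy-only sketch (random frames chosen for illustration) compares it against a per-frame loop that normalizes each frame with its own statistics; the two agree.

```python
import numpy as np

rng = np.random.default_rng(0)
frames = rng.integers(0, 256, size=(5, 32, 32, 3)).astype(np.float32)  # (N, H, W, C)

# One mean/std pair per frame and channel, as in get_params_dependent_on_data.
mean = frames.mean(axis=(1, 2))                  # (N, C)
std = np.maximum(frames.std(axis=(1, 2)), 1e-6)  # (N, C)

# Vectorized: insert the H and W axes so (N, 1, 1, C) broadcasts against (N, H, W, C).
vectorized = (frames - mean[:, np.newaxis, np.newaxis, :]) / std[:, np.newaxis, np.newaxis, :]

# Reference: normalize each frame separately with its own statistics.
looped = np.stack([(f - m) / s for f, m, s in zip(frames, mean, std)])

print(vectorized.shape, np.allclose(vectorized, looped))  # (5, 32, 32, 3) True
```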


🟠 Step 4: Geometric Transform Affecting Multiple Targets (Advanced)

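This step shows how to implement a custom geometric transform that inherits from A.DualTransform and applies the transformation logic to every target manually. Note that apply_to_bboxes and apply_to_keypoints each receive all boxes or keypoints at once, as a NumPy array in Albumentations' internal format: bounding boxes are normalized [x_min, y_min, x_max, y_max, ...], while keypoints use absolute pixel coordinates [x, y, ...].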

Goal: Create a transform that randomly shifts the image and all associated targets, implementing the target logic manually.

import albumentations as A
from albumentations.core.transforms_interface import DualTransform
import numpy as np
import cv2

class RandomShiftMultiTargetManual(DualTransform):
    """
    Randomly shifts the image and associated targets (mask, bboxes, keypoints)
    with manual implementation of target transformations, respecting internal formats.
    Processes arrays of bboxes and keypoints.
    """
    def __init__(self, x_limit=0.1, y_limit=0.1,
                 interpolation=cv2.INTER_LINEAR,
                 mask_interpolation=cv2.INTER_NEAREST,
                 border_mode=cv2.BORDER_CONSTANT,
                 fill=0,
                 fill_mask=0,
                 p=0.5):
        super().__init__(p=p)
        self.x_limit = x_limit
        self.y_limit = y_limit
        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation
        self.border_mode = border_mode
        self.fill = fill
        self.fill_mask = fill_mask

    def get_params_dependent_on_data(self, params, data):
        height, width = data["image"].shape[:2]

        dx = self.py_random.uniform(-self.x_limit, self.x_limit)
        dy = self.py_random.uniform(-self.y_limit, self.y_limit)

        x_shift = int(width * dx)
        y_shift = int(height * dy)
        matrix = np.float32([[1, 0, x_shift], [0, 1, y_shift]])
        return {
            "matrix": matrix,
            "x_shift": x_shift,
            "y_shift": y_shift,
            "height": height,
            "width": width,
        }

    def apply(self, img: np.ndarray, matrix: np.ndarray, height: int, width: int, **params):
        return cv2.warpAffine(
            img, matrix, (width, height),
            flags=self.interpolation, borderMode=self.border_mode, borderValue=self.fill
        )

    def apply_to_mask(self, mask: np.ndarray, matrix: np.ndarray, height: int, width: int, **params):
        return cv2.warpAffine(
            mask, matrix, (width, height),
            flags=self.mask_interpolation, borderMode=self.border_mode, borderValue=self.fill_mask
        )

    # Receives every bounding box at once as a single 2D array, one row per box
    def apply_to_bboxes(self, bboxes: np.ndarray, x_shift: int, y_shift: int, height: int, width: int, **params):
        """
        Applies shift to an array of bounding boxes.
        Assumes bboxes array is in internal normalized format [:, [x_min, y_min, x_max, y_max, ...]].
        Returns bboxes array in the same format.
        """
        # Calculate normalized shifts
        norm_dx = x_shift / width
        norm_dy = y_shift / height

        # Apply shifts vectorized
        bboxes_shifted = bboxes.copy()
        bboxes_shifted[:, [0, 2]] = bboxes[:, [0, 2]] + norm_dx
        bboxes_shifted[:, [1, 3]] = bboxes[:, [1, 3]] + norm_dy

        return bboxes_shifted

    # Receives every keypoint at once as a single 2D array, one row per keypoint
    def apply_to_keypoints(self, keypoints: np.ndarray, x_shift: int, y_shift: int, **params):
        """
        Applies shift to an array of keypoints.
        Assumes keypoints array is in internal format [:, [x, y, angle, scale, ...]] where x, y are absolute pixels.
        Returns keypoints array in the same format.
        """
        keypoints_shifted = keypoints.copy()
        keypoints_shifted[:, 0] = keypoints[:, 0] + x_shift
        keypoints_shifted[:, 1] = keypoints[:, 1] + y_shift

        return keypoints_shifted
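The two coordinate conventions in the docstrings explain why the box code divides the shift by the image size while the keypoint code does not. The NumPy-only check below (image size, shift, box, and keypoint values are chosen for illustration) confirms that shifting a normalized box this way matches shifting it in pixels and then normalizing.

```python
import numpy as np

height, width = 200, 400
x_shift, y_shift = 40, -20  # pixels, as computed in get_params_dependent_on_data

# One box in pixel coordinates [x_min, y_min, x_max, y_max] and its normalized form.
box_px = np.array([[100.0, 50.0, 180.0, 120.0]])
scale = np.array([width, height, width, height], dtype=np.float64)
box_norm = box_px / scale

# What apply_to_bboxes does: add the shift divided by the image size.
shifted_norm = box_norm.copy()
shifted_norm[:, [0, 2]] += x_shift / width
shifted_norm[:, [1, 3]] += y_shift / height

# Reference: shift in pixels first, then normalize.
expected = (box_px + np.array([x_shift, y_shift, x_shift, y_shift])) / scale
print(np.allclose(shifted_norm, expected))  # True

# Keypoints stay in pixels, so the raw shift is added directly.
keypoints = np.array([[150.0, 80.0, 0.0, 1.0]])  # [x, y, angle, scale]
shifted_kp = keypoints.copy()
shifted_kp[:, 0] += x_shift
shifted_kp[:, 1] += y_shift
print(shifted_kp[0, :2])  # x = 190, y = 60
```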

🟠 Custom Target Routing with apply_to_<key> (Advanced)

The standard target names route to standard methods:

  • image routes to apply
  • mask routes to apply_to_mask
  • bboxes routes to apply_to_bboxes
  • keypoints routes to apply_to_keypoints

For project-specific targets, override the targets property and map your custom key to a method on the transform:

@property
def targets(self):
    return {**super().targets, "camera": self.apply_to_camera}

Every target method receives the same parameters returned by get_params_dependent_on_data, so image pixels and custom metadata stay synchronized.

This is useful when the target is not just another image, mask, bounding box, or keypoint set. Camera intrinsics are a common example: if a random crop removes pixels from the top or left side of an image, the principal point must move by the same offset.

Goal: Create a random crop transform that crops the image and updates camera intrinsics stored in a plain dictionary. The important part is that apply and apply_to_camera receive the same sampled crop coordinates.

import albumentations as A
from albumentations.core.transforms_interface import DualTransform
import numpy as np
from typing import Any

class RandomCropWithIntrinsics(DualTransform):
    """
    Randomly crops an image and updates camera intrinsics that share
    the same sampled crop geometry.
    """

    def __init__(self, crop_height: int, crop_width: int, p: float = 1.0):
        super().__init__(p=p)
        self.crop_height = crop_height
        self.crop_width = crop_width

    @property
    def targets(self) -> dict[str, Any]:
        return {
            **super().targets,
            "camera": self.apply_to_camera,
        }

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]):
        image_height, image_width = data["image"].shape[:2]

        if self.crop_height > image_height or self.crop_width > image_width:
            raise ValueError(
                f"Crop size ({self.crop_height}, {self.crop_width}) must be smaller than "
                f"or equal to image size ({image_height}, {image_width})."
            )

        x_min = self.py_random.randint(0, image_width - self.crop_width)
        y_min = self.py_random.randint(0, image_height - self.crop_height)
        x_max = x_min + self.crop_width
        y_max = y_min + self.crop_height

        return {
            "x_min": x_min,
            "y_min": y_min,
            "x_max": x_max,
            "y_max": y_max,
        }

    def apply(
        self,
        img: np.ndarray,
        x_min: int,
        y_min: int,
        x_max: int,
        y_max: int,
        **params: Any,
    ) -> np.ndarray:
        return img[y_min:y_max, x_min:x_max]

    def apply_to_camera(
        self,
        camera: dict[str, float],
        x_min: int,
        y_min: int,
        **params: Any,
    ) -> dict[str, float]:
        camera = camera.copy()

        # Cropping changes the coordinate origin. Focal lengths do not change.
        camera["cx"] = camera["cx"] - x_min
        camera["cy"] = camera["cy"] - y_min
        camera["width"] = self.crop_width
        camera["height"] = self.crop_height

        return camera

# --- Usage ---
transform = A.Compose([
    RandomCropWithIntrinsics(crop_height=240, crop_width=320, p=1.0),
], seed=137)

image = np.zeros((480, 640, 3), dtype=np.uint8)
camera = {
    "fx": 500.0,
    "fy": 500.0,
    "cx": 320.0,
    "cy": 240.0,
    "width": 640,
    "height": 480,
}

result = transform(image=image, camera=camera)

cropped_image = result["image"]
updated_camera = result["camera"]

print(cropped_image.shape)
print(updated_camera)

If the same transform also resizes the crop, update the intrinsics for both operations: subtract the crop origin from cx and cy, then scale fx, fy, cx, and cy by the resize factors. Update width and height to the final output size.
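As a worked example of that rule, the hypothetical helper crop_then_resize_intrinsics below crops the central 320x240 window of the 640x480 camera from the usage example and resizes it back to 640x480. It assumes the simple convention in which resizing scales every coordinate by the same factor; some pipelines that treat integer coordinates as pixel centers add a half-pixel correction, which is omitted here.

```python
def crop_then_resize_intrinsics(
    camera: dict[str, float],
    x_min: int,
    y_min: int,
    crop_width: int,
    crop_height: int,
    out_width: int,
    out_height: int,
) -> dict[str, float]:
    """Update pinhole intrinsics for a crop followed by a resize (hypothetical helper)."""
    sx = out_width / crop_width    # horizontal resize factor
    sy = out_height / crop_height  # vertical resize factor
    return {
        **camera,  # returns a new dict; the input is not modified
        "fx": camera["fx"] * sx,
        "fy": camera["fy"] * sy,
        # Move the principal point into crop coordinates, then scale it.
        "cx": (camera["cx"] - x_min) * sx,
        "cy": (camera["cy"] - y_min) * sy,
        "width": out_width,
        "height": out_height,
    }


camera = {"fx": 500.0, "fy": 500.0, "cx": 320.0, "cy": 240.0, "width": 640, "height": 480}
# Crop the central 320x240 window, then resize it back up to 640x480 (2x).
updated = crop_then_resize_intrinsics(camera, x_min=160, y_min=120,
                                      crop_width=320, crop_height=240,
                                      out_width=640, out_height=480)
print(updated)
```

Here cx becomes (320 - 160) * 2 = 320 and fx becomes 500 * 2 = 1000: the crop moves the principal point, while only the resize changes the focal lengths.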

When to Use Standard Targets, additional_targets, or Custom Routing

If an auxiliary array can be treated exactly like a standard target, prefer the standard target path. For example, if you have an image and a depth map but no separate segmentation mask, pass the depth map through the mask argument:

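import albumentations as A
import cv2
import numpy as np

image = np.zeros((480, 640, 3), dtype=np.uint8)
depth_map = np.ones((480, 640), dtype=np.float32)  # placeholder depth values

# Continuous depth values: resample with linear interpolation.
transform = A.Compose([
    A.Resize(height=240, width=320, p=1.0),
], mask_interpolation=cv2.INTER_LINEAR, seed=137)

result = transform(image=image, mask=depth_map)
resized_depth = result["mask"]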

Use additional_targets when the standard slot is already occupied or when preserving a separate semantic key is useful. For example, use additional_targets={"depth": "mask"} when you need to pass both mask=segmentation_mask and depth=depth_map through the same pipeline.

Use custom apply_to_<key> routing when the target has custom semantics: camera intrinsics, calibration records, affine matrices, coordinate-system metadata, crop provenance, invalid-value handling, confidence-specific behavior, or target-specific post-processing.

Using Custom Transforms in Pipelines

Once defined, your custom transform class is used just like any built-in Albumentations transform. Simply instantiate it and add it to your A.Compose list.


🔴 Advanced Topics

Reproducibility and Random Number Generation

To ensure your custom transforms produce reproducible results when a seed is set with A.Compose(..., seed=...), use the random number generators provided by the base transform classes: self.py_random and self.random_generator.

Using Python's global random module or NumPy's global np.random bypasses Albumentations' seeding mechanism. Fixing global seeds outside the pipeline does not guarantee reproducibility inside custom transforms.

Use:

  • self.py_random for standard Python-style sampling, for example self.py_random.uniform(a, b).
  • self.random_generator for NumPy-style sampling, for example self.random_generator.uniform(a, b, size).
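The difference between private and global generators can be seen without Albumentations. In this sketch, the hypothetical sample_factors function builds its own random.Random and np.random.default_rng from a seed, analogous in spirit to self.py_random and self.random_generator. Unrelated draws from the global generators do not affect it, whereas with the global random module any extra draw in between changes the result.

```python
import random

import numpy as np


def sample_factors(seed: int) -> tuple[float, np.ndarray]:
    # Private generators, analogous to self.py_random and self.random_generator.
    py_random = random.Random(seed)
    random_generator = np.random.default_rng(seed)
    return py_random.uniform(0.8, 1.2), random_generator.uniform(0.8, 1.2, size=3)


first = sample_factors(137)
random.random()   # unrelated code consuming the global Python RNG
np.random.rand()  # unrelated code consuming the global NumPy RNG
second = sample_factors(137)
print(first[0] == second[0], np.array_equal(first[1], second[1]))  # True True

# With the global module, any extra draw in between changes what you get.
random.seed(137)
a = random.random()
random.seed(137)
random.random()   # someone else draws first
b = random.random()
print(a == b)  # False
```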

For a comprehensive guide, see the Reproducibility Guide.

🟠 Input Parameter Validation with InitSchema

Use InitSchema to validate transform configuration at instantiation time. This is especially useful when parameters are mutually exclusive, constrained to ranges, or must satisfy relationships such as odd kernel sizes.

Define an inner class named InitSchema that inherits from pydantic.BaseModel and declares each constructor argument, including p. Albumentations finds this schema automatically and validates the arguments against it at instantiation time, before the transform is fully initialized.

Example: Mutually Exclusive Parameters

Let's create a hypothetical CustomBlur transform that can apply either a Gaussian blur (requiring sigma) or a Box blur (requiring kernel_size), but not both.

import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
import numpy as np
import cv2

from pydantic import BaseModel, Field, model_validator

class CustomBlur(ImageOnlyTransform):
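    class InitSchema(BaseModel):
        sigma: float | None = Field(default=None, ge=0, description="Sigma for Gaussian blur.")
        kernel_size: int | None = Field(
            default=None,
            ge=3,
            description="Kernel size for Box blur (must be odd).",
        )
        # Declare p as well so every constructor argument is validated.
        p: float = Field(default=0.5, ge=0, le=1, description="Probability of applying the transform.")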

        @model_validator(mode="after")
        def check_blur_params(self):
            sigma_set = self.sigma is not None
            kernel_set = self.kernel_size is not None

            if sigma_set and kernel_set:
                raise ValueError("Specify either 'sigma' or 'kernel_size', not both.")
            if not sigma_set and not kernel_set:
                raise ValueError("Must specify either 'sigma' (for Gaussian) or 'kernel_size' (for Box).")
            if kernel_set and self.kernel_size % 2 == 0:
                raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")
            return self

    def __init__(self, sigma: float | None = None, kernel_size: int | None = None, p=0.5):
        super().__init__(p=p)
        self.sigma = sigma
        self.kernel_size = kernel_size

    def apply(self, img: np.ndarray, **params):
        if self.sigma is not None:
            sigma_safe = max(self.sigma, 1e-6)
            return cv2.GaussianBlur(img, ksize=(0, 0), sigmaX=sigma_safe)
        if self.kernel_size is not None:
            return cv2.blur(img, ksize=(self.kernel_size, self.kernel_size))
        return img

transform_gauss = CustomBlur(sigma=1.5, p=1.0)
transform_box = CustomBlur(kernel_size=5, p=1.0)

# These fail during initialization:
# CustomBlur()
# CustomBlur(sigma=1.5, kernel_size=5)
# CustomBlur(kernel_size=4)

This example uses @model_validator for rules that involve multiple fields. Use Field constraints for simple per-field validation.
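For contrast, a schema whose rules each involve only one field needs no model_validator. The sketch below uses pydantic v2 directly and runs without Albumentations; the MotionStreak transform and its arguments (length, angle) are hypothetical, invented for illustration. Range limits go in Field, and a field_validator handles a single-field rule such as oddness.

```python
from pydantic import BaseModel, Field, ValidationError, field_validator


class MotionStreakSchema(BaseModel):
    """Hypothetical InitSchema for an imaginary MotionStreak transform."""

    length: int = Field(default=7, ge=3, le=101, description="Streak length in pixels (odd).")
    angle: float = Field(default=0.0, ge=-180, le=180, description="Streak angle in degrees.")
    p: float = Field(default=0.5, ge=0, le=1)

    @field_validator("length")
    @classmethod
    def length_must_be_odd(cls, value: int) -> int:
        # Runs after the ge/le constraints on length have passed.
        if value % 2 == 0:
            raise ValueError(f"length must be odd, got {value}")
        return value


def is_valid(**kwargs) -> bool:
    try:
        MotionStreakSchema(**kwargs)
    except ValidationError:
        return False
    return True


print(is_valid(length=9, angle=45))  # True
print(is_valid(length=8))            # False: even
print(is_valid(length=1))            # False: below ge=3
print(is_valid(angle=270))           # False: outside [-180, 180]
```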

🟠 Passing Arbitrary Data via targets_as_params

Standard Albumentations targets such as image, mask, bboxes, and keypoints are handled by the base transform classes. Sometimes, however, you need extra data that is not itself a transformed output but is needed to determine augmentation parameters.

Use targets_as_params when the data should influence parameter sampling. It does not automatically transform or return that data. If the same data also needs to be updated and returned, register it as a custom target with apply_to_<key> as shown above.

Examples include:

  • Overlay images/masks to be blended onto the main image.
  • Metadata associated with the image (e.g., timestamps, sensor readings).
  • Paths to auxiliary files.
  • Pre-computed model weights or features.

Albumentations provides this using the targets_as_params property.

  1. Override targets_as_params: In your custom transform class, override the targets_as_params property. It should return a list of strings. Each string is a key that you expect to find in the input data dictionary passed to the Compose pipeline.

    @property
    def targets_as_params(self):
        return ["my_custom_metadata_key", "overlay_image_data"]
    
  2. Pass Data to Compose: When calling the pipeline, include your custom data using the keys defined in targets_as_params.

    pipeline = A.Compose([
        MyCustomTransform(...),
        ...
    ])
    
    # Prepare your custom data
    custom_metadata = {"info": "some_value", "timestamp": 12345}
    overlay_data = {"image": overlay_img, "mask": overlay_mask}
    
    # Pass it during the call
    result = pipeline(
        image=main_image,
        mask=main_mask,
        my_custom_metadata_key=custom_metadata,
        overlay_image_data=overlay_data,
    )
    
  3. Access Data in get_params_dependent_on_data: The data you passed (associated with the keys listed in targets_as_params) will be available inside the data dictionary argument of your get_params_dependent_on_data method.

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]):
        # Slice to (height, width) in case params["shape"] also includes channels
        height, width = params["shape"][:2]
    
        # Access your custom data
        custom_info = data["my_custom_metadata_key"]["info"]
        overlay_img = data["overlay_image_data"]["image"]
    
        # --- Now use this data to calculate augmentation parameters --- #
        # Example: Decide blur strength based on custom_info
        if custom_info == "high_detail":
            blur_sigma = self.py_random.uniform(0.1, 0.5)
        else:
            blur_sigma = self.py_random.uniform(1.0, 2.0)
    
        # Example: process the overlay image (resize, compute offset, etc.).
        # process_overlay is a hypothetical helper you would implement yourself.
        processed_overlay_params = self.process_overlay(overlay_img, (height, width))
    
        # Return calculated parameters to be used in apply methods
        return {
            "blur_sigma": blur_sigma,
            "processed_overlay": processed_overlay_params
        }
    
  4. Use Parameters in apply and apply_to_* Methods: The dictionary returned by get_params_dependent_on_data is passed via **params to apply, apply_to_mask, apply_to_bboxes, and the other target methods.

    # apply_blur, blend_overlay, and apply_overlay_mask are hypothetical helpers.
    def apply(self, img: np.ndarray, blur_sigma: float, processed_overlay: dict, **params):
        img = apply_blur(img, blur_sigma)
        return blend_overlay(img, processed_overlay["image"], processed_overlay["offset"])
    
    def apply_to_mask(self, mask: np.ndarray, processed_overlay: dict, **params):
        return apply_overlay_mask(mask, processed_overlay["mask"], processed_overlay["mask_id"])
    

Complete Example: AverageWithExternalImage Transform

The following AverageWithExternalImage transform demonstrates passing an external image via a custom key and averaging it with the main input image.

import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
import numpy as np
import cv2
from typing import Any

class AverageWithExternalImage(ImageOnlyTransform):
    """
    Averages the input image with an external image passed via `targets_as_params`.
    The external image is resized to match the input image dimensions.
    """

    def __init__(self, external_image_key: str = "external_image", p: float = 0.5):
        """
        Args:
            external_image_key (str): The key used to pass the external image data
                                     when calling the pipeline.
            p (float): Probability of applying the transform.
        """
        super().__init__(p=p)
        self.external_image_key = external_image_key

    @property
    def targets_as_params(self):
        return [self.external_image_key]

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]):
        external_image = data[self.external_image_key]
        if not isinstance(external_image, np.ndarray):
            raise TypeError(
                f"Expected '{self.external_image_key}' to be a NumPy ndarray, "
                f"got {type(external_image)}"
            )

        target_height, target_width = params["shape"][:2]
        resized_external_image = cv2.resize(
            external_image, (target_width, target_height), interpolation=cv2.INTER_LINEAR
        )

        return {"resized_external_image": resized_external_image}

    def apply(self, img: np.ndarray, resized_external_image: np.ndarray, **params: Any):
        if img.shape != resized_external_image.shape:
            raise ValueError(
                f"Shape mismatch between image {img.shape} and external image "
                f"{resized_external_image.shape}."
            )

        avg_image = cv2.addWeighted(img, 0.5, resized_external_image, 0.5, 0.0)
        return avg_image.astype(img.dtype)

# --- Example Usage ---
pipeline = A.Compose([
    AverageWithExternalImage(external_image_key="reference_photo", p=1.0),
    # Other transforms...
])

main_image = np.zeros((256, 256, 3), dtype=np.uint8)
reference_photo = np.full((128, 128, 3), 255, dtype=np.uint8)

augmented_data = pipeline(image=main_image, reference_photo=reference_photo)
averaged_image = augmented_data["image"]

Albumentations requires every key listed in targets_as_params to be present in the pipeline call. Use this mechanism when data should influence parameter sampling or intermediate computation; use custom target routing when that data should also be transformed and returned.
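In AverageWithExternalImage, the averaging goes through cv2.addWeighted, which computes the weighted sum without wrapping around and saturates the result to the output type. Writing the average by hand on uint8 arrays is a common pitfall, as this NumPy-only sketch shows (pixel values chosen for illustration).

```python
import numpy as np

img = np.array([[200, 10]], dtype=np.uint8)
external = np.array([[100, 20]], dtype=np.uint8)

# Naive average: the uint8 sum wraps around (200 + 100 = 300 -> 44) before dividing.
naive = (img + external) / 2

# Safe average: widen the dtype first, then cast back.
safe = ((img.astype(np.uint16) + external) // 2).astype(np.uint8)

print(naive)  # [[22. 15.]]
print(safe)   # [[150  15]]
```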

Where to Go Next?

Now that you can create your own custom transforms:

  • Integrate Your Transform: Add your custom transform to pipelines within the Basic Usage Guides relevant to your task.
  • Explore Base Class APIs: Consult the API Reference for details on ImageOnlyTransform, DualTransform, Transform3D, and other base classes you might inherit from.
  • Handle Advanced Scenarios: Learn how custom transforms interact with Additional Targets or how they can be included in Serialization.
  • Revisit Core Concepts: Ensure your custom transform correctly handles Targets and fits within the Pipeline structure.
  • Contribute (Optional): If your transform is broadly useful, consider contributing it back to the Albumentations library (see the project's contribution guidelines).