Functional transforms (augmentations.functional)¶
class MacenkoNormalizer
(angular_percentile=99)
[view source on GitHub] ¶
Macenko stain normalizer with optimized computations.
Source code in albumentations/augmentations/functional.py
class MacenkoNormalizer(StainNormalizer):
"""Macenko stain normalizer with optimized computations."""
def __init__(self, angular_percentile: float = 99):
super().__init__()
self.angular_percentile = angular_percentile
def fit(self, img: np.ndarray, angular_percentile: float = 99) -> None:
"""Extract H&E stain matrix using optimized Macenko's method."""
# Step 1: Convert RGB to optical density (OD) space
optical_density = rgb_to_optical_density(img)
# Step 2: Remove background pixels
od_threshold = 0.05
threshold_mask = (optical_density > od_threshold).any(axis=1)
tissue_density = optical_density[threshold_mask]
if len(tissue_density) < 1:
raise ValueError(f"No tissue pixels found (threshold={od_threshold})")
# Step 3: Compute covariance matrix
tissue_density = np.ascontiguousarray(tissue_density, dtype=np.float32)
od_covariance = cv2.calcCovarMatrix(
tissue_density,
None,
cv2.COVAR_NORMAL | cv2.COVAR_ROWS | cv2.COVAR_SCALE,
)[0]
# Step 4: Get principal components
eigenvalues, eigenvectors = cv2.eigen(od_covariance)[1:]
idx = np.argsort(eigenvalues.ravel())[-2:]
principal_eigenvectors = np.ascontiguousarray(eigenvectors[:, idx], dtype=np.float32)
# Step 5: Project onto eigenvector plane
plane_coordinates = tissue_density @ principal_eigenvectors
# Step 6: Find angles of extreme points
polar_angles = np.arctan2(
plane_coordinates[:, 1],
plane_coordinates[:, 0],
)
# Get robust angle estimates
hematoxylin_angle = np.percentile(polar_angles, 100 - angular_percentile)
eosin_angle = np.percentile(polar_angles, angular_percentile)
# Step 7: Convert angles back to RGB space
hem_cos, hem_sin = np.cos(hematoxylin_angle), np.sin(hematoxylin_angle)
eos_cos, eos_sin = np.cos(eosin_angle), np.sin(eosin_angle)
angle_to_vector = np.array(
[[hem_cos, hem_sin], [eos_cos, eos_sin]],
dtype=np.float32,
)
stain_vectors = cv2.gemm(
angle_to_vector,
principal_eigenvectors.T,
1,
None,
0,
)
# Step 8: Ensure non-negativity by taking absolute values
# This is valid because stain vectors represent absorption coefficients
stain_vectors = np.abs(stain_vectors)
# Step 9: Normalize vectors to unit length
stain_vectors = stain_vectors / np.sqrt(np.sum(stain_vectors**2, axis=1, keepdims=True))
# Step 10: Order vectors as [hematoxylin, eosin]
# Hematoxylin typically has larger red component
self.stain_matrix_target = stain_vectors if stain_vectors[0, 0] > stain_vectors[1, 0] else stain_vectors[::-1]
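A minimal usage sketch (hedged: only the `fit` API shown above is exercised, and the random array merely stands in for a real H&E-stained tile; the import path follows the "Source code in albumentations/augmentations/functional.py" note):
>>> import numpy as np
>>> from albumentations.augmentations.functional import MacenkoNormalizer
>>> he_tile = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # stand-in for an H&E tile
>>> normalizer = MacenkoNormalizer(angular_percentile=99)
>>> normalizer.fit(he_tile)  # estimates the stain matrix from the tile
>>> normalizer.stain_matrix_target.shape  # [hematoxylin, eosin] rows in RGB space
(2, 3)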
class StainNormalizer
()
[view source on GitHub] ¶
Base class for stain normalizers.
class VahadaneNormalizer
[view source on GitHub] ¶
Source code in albumentations/augmentations/functional.py
class VahadaneNormalizer(StainNormalizer):
def fit(self, img: np.ndarray) -> None:
optical_density = rgb_to_optical_density(img)
nmf = SimpleNMF(n_iter=100)
_, stain_colors = nmf.fit_transform(optical_density)
# Use combined method for robust stain ordering
hematoxylin_idx, eosin_idx = order_stains_combined(stain_colors)
self.stain_matrix_target = np.array(
[
stain_colors[hematoxylin_idx],
stain_colors[eosin_idx],
],
)
def add_rain (img, slant, drop_length, drop_width, drop_color, blur_value, brightness_coefficient, rain_drops)
[view source on GitHub]¶
Add rain streaks to the image using OpenCV line drawing (optimized implementation).
Source code in albumentations/augmentations/functional.py
@uint8_io
@preserve_channel_dim
def add_rain(
img: np.ndarray,
slant: float,
drop_length: int,
drop_width: int,
drop_color: tuple[int, int, int],
blur_value: int,
brightness_coefficient: float,
rain_drops: np.ndarray,
) -> np.ndarray:
"""Optimized version using OpenCV line drawing."""
if not rain_drops.size:
return img.copy()
img = img.copy()
# Pre-allocate rain layer
rain_layer = np.zeros_like(img, dtype=np.uint8)
# Compute drop end points from the origins, slant, and drop length
end_points = rain_drops + np.array([[slant, drop_length]])
# Stack start and end points into an (N, 2, 2) array of line segments
lines = np.stack((rain_drops, end_points), axis=1)
cv2.polylines(
rain_layer,
lines.astype(np.int32),
False,
drop_color,
drop_width,
lineType=cv2.LINE_4,
)
if blur_value > 1:
cv2.blur(rain_layer, (blur_value, blur_value), dst=rain_layer)
cv2.add(img, rain_layer, dst=img)
if brightness_coefficient != 1.0:
cv2.multiply(img, brightness_coefficient, dst=img, dtype=cv2.CV_8U)
return img
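A short usage sketch (hedged: the `rain_drops` array of (x, y) drop origins is normally sampled by the calling transform; here it is drawn at random purely for illustration):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> rng = np.random.default_rng(0)
>>> rain_drops = rng.integers(0, 100, size=(50, 2))  # (x, y) origins of 50 drops
>>> rainy = F.add_rain(
...     image, slant=5, drop_length=10, drop_width=1,
...     drop_color=(200, 200, 200), blur_value=3,
...     brightness_coefficient=0.9, rain_drops=rain_drops,
... )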
def add_shadow (img, vertices_list, intensities)
[view source on GitHub]¶
Add shadows to the image by reducing the intensity of the pixel values in specified regions.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. Multichannel images are supported. |
vertices_list | list[np.ndarray] | List of vertices for shadow polygons. |
intensities | np.ndarray | Array of shadow intensities. Range is [0, 1]. |
Returns:
Type | Description |
---|---|
np.ndarray | Image with shadows added. |
Source code in albumentations/augmentations/functional.py
@uint8_io
@preserve_channel_dim
def add_shadow(
img: np.ndarray,
vertices_list: list[np.ndarray],
intensities: np.ndarray,
) -> np.ndarray:
"""Add shadows to the image by reducing the intensity of the pixel values in specified regions.
Args:
img (np.ndarray): Input image. Multichannel images are supported.
vertices_list (list[np.ndarray]): List of vertices for shadow polygons.
intensities (np.ndarray): Array of shadow intensities. Range is [0, 1].
Returns:
np.ndarray: Image with shadows added.
Reference:
https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
"""
num_channels = get_num_channels(img)
max_value = MAX_VALUES_BY_DTYPE[np.uint8]
img_shadowed = img.copy()
# Iterate over the vertices and intensity list
for vertices, shadow_intensity in zip(vertices_list, intensities):
# Create mask for the current shadow polygon
mask = np.zeros((img.shape[0], img.shape[1], 1), dtype=np.uint8)
cv2.fillPoly(mask, [vertices], (max_value,))
# Duplicate the mask to have the same number of channels as the image
mask = np.repeat(mask, num_channels, axis=2)
# Apply shadow to the channels directly
# It could be tempting to convert to HLS and apply the shadow to the L channel, but it creates artifacts
shadowed_indices = mask[:, :, 0] == max_value
darkness = 1 - shadow_intensity
img_shadowed[shadowed_indices] = clip(
img_shadowed[shadowed_indices] * darkness,
np.uint8,
inplace=True,
)
return img_shadowed
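A short usage sketch (the polygon vertices and intensities are normally sampled by the calling transform; hand-built here for illustration):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> triangle = np.array([[10, 10], [80, 20], [40, 90]], dtype=np.int32)  # one shadow polygon
>>> shadowed = F.add_shadow(image, vertices_list=[triangle], intensities=np.array([0.5]))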
def add_snow_bleach (img, snow_point, brightness_coeff)
[view source on GitHub]¶
Adds a simple snow effect to the image by bleaching out pixels.
This function simulates a basic snow effect by increasing the brightness of pixels that are below a certain threshold (snow_point). It operates in the HLS color space to modify the lightness channel.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. Can be either RGB uint8 or float32. |
snow_point | float | A float in the range [0, 1], scaled and adjusted to determine the threshold for pixel modification. Higher values raise the threshold, brightening more pixels and producing a stronger snow effect. |
brightness_coeff | float | Coefficient applied to increase the brightness of pixels below the snow_point threshold. Larger values lead to more pronounced snow effects. Should be greater than 1.0 for a visible effect. |
Returns:
Type | Description |
---|---|
np.ndarray | Image with simulated snow effect. The output has the same dtype as the input. |
Note
- This function converts the image to the HLS color space to modify the lightness channel.
- The snow effect is created by selectively increasing the brightness of pixels.
- This method tends to create a 'bleached' look, which may not be as realistic as more advanced snow simulation techniques.
- The function automatically handles both uint8 and float32 input images.
The snow effect is created through the following steps:
1. Convert the image from RGB to HLS color space.
2. Adjust the snow_point threshold.
3. Increase the lightness of pixels below the threshold.
4. Convert the image back to RGB.
Mathematical Formulation:
Let L be the lightness channel in HLS space. For each pixel (i, j):
if L[i, j] < snow_point: L[i, j] = L[i, j] * brightness_coeff
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
>>> snowy_image = A.functional.add_snow_bleach(image, snow_point=0.5, brightness_coeff=1.5)
References
- HLS Color Space: https://en.wikipedia.org/wiki/HSL_and_HSV
- Original implementation: https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
Source code in albumentations/augmentations/functional.py
@uint8_io
def add_snow_bleach(
img: np.ndarray,
snow_point: float,
brightness_coeff: float,
) -> np.ndarray:
"""Adds a simple snow effect to the image by bleaching out pixels.
This function simulates a basic snow effect by increasing the brightness of pixels
that are below a certain threshold (snow_point). It operates in the HLS color space
to modify the lightness channel.
Args:
img (np.ndarray): Input image. Can be either RGB uint8 or float32.
snow_point (float): A float in the range [0, 1], scaled and adjusted to determine
the threshold for pixel modification. Higher values raise the threshold,
brightening more pixels and producing a stronger snow effect.
brightness_coeff (float): Coefficient applied to increase the brightness of pixels
below the snow_point threshold. Larger values lead to more pronounced snow effects.
Should be greater than 1.0 for a visible effect.
Returns:
np.ndarray: Image with simulated snow effect. The output has the same dtype as the input.
Note:
- This function converts the image to the HLS color space to modify the lightness channel.
- The snow effect is created by selectively increasing the brightness of pixels.
- This method tends to create a 'bleached' look, which may not be as realistic as more
advanced snow simulation techniques.
- The function automatically handles both uint8 and float32 input images.
The snow effect is created through the following steps:
1. Convert the image from RGB to HLS color space.
2. Adjust the snow_point threshold.
3. Increase the lightness of pixels below the threshold.
4. Convert the image back to RGB.
Mathematical Formulation:
Let L be the lightness channel in HLS space.
For each pixel (i, j):
If L[i, j] < snow_point:
L[i, j] = L[i, j] * brightness_coeff
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
>>> snowy_image = A.functional.add_snow_bleach(image, snow_point=0.5, brightness_coeff=1.5)
References:
- HLS Color Space: https://en.wikipedia.org/wiki/HSL_and_HSV
- Original implementation: https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
"""
max_value = MAX_VALUES_BY_DTYPE[np.uint8]
snow_point *= max_value / 2
snow_point += max_value / 3
image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
image_hls = np.array(image_hls, dtype=np.float32)
image_hls[:, :, 1][image_hls[:, :, 1] < snow_point] *= brightness_coeff
image_hls[:, :, 1] = clip(image_hls[:, :, 1], np.uint8, inplace=True)
image_hls = np.array(image_hls, dtype=np.uint8)
return cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
def add_snow_texture (img, snow_point, brightness_coeff, snow_texture, sparkle_mask)
[view source on GitHub]¶
Add a realistic snow effect to the input image.
This function simulates snowfall by applying multiple visual effects to the image, including brightness adjustment, snow texture overlay, depth simulation, and color tinting. The result is a more natural-looking snow effect compared to simple pixel bleaching methods.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image in RGB format. |
snow_point | float | Coefficient that controls the amount and intensity of snow. Should be in the range [0, 1], where 0 means no snow and 1 means maximum snow effect. |
brightness_coeff | float | Coefficient for brightness adjustment to simulate the reflective nature of snow. Should be in the range [0, 1], where higher values result in a brighter image. |
snow_texture | np.ndarray | Snow texture. |
sparkle_mask | np.ndarray | Sparkle mask. |
Returns:
Type | Description |
---|---|
np.ndarray | Image with added snow effect. The output has the same dtype as the input. |
Note
- The function first converts the image to HSV color space for better control over brightness and color adjustments.
- A snow texture is generated using Gaussian noise and then filtered for a more natural appearance.
- A depth effect is simulated, with more snow at the top of the image and less at the bottom.
- A slight blue tint is added to simulate the cool color of snow.
- Random sparkle effects are added to simulate light reflecting off snow crystals.
The snow effect is created through the following steps:
1. Brightness adjustment in HSV space
2. Generation of a snow texture using Gaussian noise
3. Application of a depth effect to the snow texture
4. Blending of the snow texture with the original image
5. Addition of a cool blue tint
6. Addition of sparkle effects
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
>>> snow_texture = np.random.rand(100, 100).astype(np.float32)
>>> sparkle_mask = np.random.rand(100, 100) > 0.99
>>> snowy_image = A.functional.add_snow_texture(
...     image, snow_point=0.5, brightness_coeff=0.2,
...     snow_texture=snow_texture, sparkle_mask=sparkle_mask)
Note
This function works with both uint8 and float32 image types, automatically handling the conversion between them.
References
- Perlin Noise: https://en.wikipedia.org/wiki/Perlin_noise
- HSV Color Space: https://en.wikipedia.org/wiki/HSL_and_HSV
Source code in albumentations/augmentations/functional.py
@uint8_io
def add_snow_texture(
img: np.ndarray,
snow_point: float,
brightness_coeff: float,
snow_texture: np.ndarray,
sparkle_mask: np.ndarray,
) -> np.ndarray:
"""Add a realistic snow effect to the input image.
This function simulates snowfall by applying multiple visual effects to the image,
including brightness adjustment, snow texture overlay, depth simulation, and color tinting.
The result is a more natural-looking snow effect compared to simple pixel bleaching methods.
Args:
img (np.ndarray): Input image in RGB format.
snow_point (float): Coefficient that controls the amount and intensity of snow.
Should be in the range [0, 1], where 0 means no snow and 1 means maximum snow effect.
brightness_coeff (float): Coefficient for brightness adjustment to simulate the
reflective nature of snow. Should be in the range [0, 1], where higher values
result in a brighter image.
snow_texture (np.ndarray): Snow texture.
sparkle_mask (np.ndarray): Sparkle mask.
Returns:
np.ndarray: Image with added snow effect. The output has the same dtype as the input.
Note:
- The function first converts the image to HSV color space for better control over
brightness and color adjustments.
- A snow texture is generated using Gaussian noise and then filtered for a more
natural appearance.
- A depth effect is simulated, with more snow at the top of the image and less at the bottom.
- A slight blue tint is added to simulate the cool color of snow.
- Random sparkle effects are added to simulate light reflecting off snow crystals.
The snow effect is created through the following steps:
1. Brightness adjustment in HSV space
2. Generation of a snow texture using Gaussian noise
3. Application of a depth effect to the snow texture
4. Blending of the snow texture with the original image
5. Addition of a cool blue tint
6. Addition of sparkle effects
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
>>> snow_texture = np.random.rand(100, 100).astype(np.float32)
>>> sparkle_mask = np.random.rand(100, 100) > 0.99
>>> snowy_image = A.functional.add_snow_texture(
...     image, snow_point=0.5, brightness_coeff=0.2,
...     snow_texture=snow_texture, sparkle_mask=sparkle_mask)
Note:
This function works with both uint8 and float32 image types, automatically
handling the conversion between them.
References:
- Perlin Noise: https://en.wikipedia.org/wiki/Perlin_noise
- HSV Color Space: https://en.wikipedia.org/wiki/HSL_and_HSV
"""
max_value = MAX_VALUES_BY_DTYPE[np.uint8]
# Convert to HSV for better color control
img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
# Increase brightness
img_hsv[:, :, 2] = np.clip(
img_hsv[:, :, 2] * (1 + brightness_coeff * snow_point),
0,
max_value,
)
# Generate snow texture
snow_texture = cv2.GaussianBlur(snow_texture, (0, 0), sigmaX=1, sigmaY=1)
# Create depth effect for snow simulation
# More snow accumulates at the top of the image, gradually decreasing towards the bottom
# This simulates natural snow distribution on surfaces
# The effect is achieved using a linear gradient from 1 (full snow) to 0.2 (less snow)
rows = img.shape[0]
depth_effect = np.linspace(1, 0.2, rows)[:, np.newaxis]
snow_texture *= depth_effect
# Apply snow texture
snow_layer = (np.dstack([snow_texture] * 3) * max_value * snow_point).astype(
np.float32,
)
# Blend snow with original image
img_with_snow = cv2.add(img_hsv, snow_layer)
# Add a slight blue tint to simulate cool snow color
blue_tint = np.full_like(img_with_snow, (0.6, 0.75, 1)) # Slight blue in HSV
img_with_snow = cv2.addWeighted(
img_with_snow,
0.85,
blue_tint,
0.15 * snow_point,
0,
)
# Convert back to RGB
img_with_snow = cv2.cvtColor(img_with_snow.astype(np.uint8), cv2.COLOR_HSV2RGB)
# Add some sparkle effects for snow glitter
img_with_snow[sparkle_mask] = [max_value, max_value, max_value]
return img_with_snow
def add_sun_flare_overlay (img, flare_center, src_radius, src_color, circles)
[view source on GitHub]¶
Add a sun flare effect to an image using a simple overlay technique.
This function creates a basic sun flare effect by overlaying multiple semi-transparent circles of varying sizes and intensities on the input image. The effect simulates a simple lens flare caused by bright light sources.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | The input image. |
flare_center | tuple[float, float] | (x, y) coordinates of the flare center in pixel coordinates. |
src_radius | int | The radius of the main sun circle in pixels. |
src_color | tuple[int, ...] | The color of the sun, represented as a tuple of RGB values. |
circles | list[Any] | A list of tuples, each representing a circle that contributes to the flare effect. Each tuple contains: - alpha (float): The transparency of the circle (0.0 to 1.0). - center (tuple[int, int]): (x, y) coordinates of the circle center. - radius (int): The radius of the circle. - color (tuple[int, int, int]): RGB color of the circle. |
Returns:
Type | Description |
---|---|
np.ndarray | The output image with the sun flare effect added. |
Note
- This function uses a simple alpha blending technique to overlay flare elements.
- The main sun is created as a gradient circle, fading from the center outwards.
- Additional flare circles are added along an imaginary line from the sun's position.
- This method is computationally efficient but may produce less realistic results compared to more advanced techniques.
The flare effect is created through the following steps:
1. Create an overlay image and output image as copies of the input.
2. Add smaller flare circles to the overlay.
3. Blend the overlay with the output image using alpha compositing.
4. Add the main sun circle with a radial gradient.
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
>>> flare_center = (50, 50)
>>> src_radius = 20
>>> src_color = (255, 255, 200)
>>> circles = [
... (0.1, (60, 60), 5, (255, 200, 200)),
... (0.2, (70, 70), 3, (200, 255, 200))
... ]
>>> flared_image = A.functional.add_sun_flare_overlay(
... image, flare_center, src_radius, src_color, circles
... )
References
- Alpha compositing: https://en.wikipedia.org/wiki/Alpha_compositing
- Lens flare: https://en.wikipedia.org/wiki/Lens_flare
Source code in albumentations/augmentations/functional.py
@uint8_io
@preserve_channel_dim
@maybe_process_in_chunks
def add_sun_flare_overlay(
img: np.ndarray,
flare_center: tuple[float, float],
src_radius: int,
src_color: tuple[int, ...],
circles: list[Any],
) -> np.ndarray:
"""Add a sun flare effect to an image using a simple overlay technique.
This function creates a basic sun flare effect by overlaying multiple semi-transparent
circles of varying sizes and intensities on the input image. The effect simulates
a simple lens flare caused by bright light sources.
Args:
img (np.ndarray): The input image.
flare_center (tuple[float, float]): (x, y) coordinates of the flare center
in pixel coordinates.
src_radius (int): The radius of the main sun circle in pixels.
src_color (tuple[int, ...]): The color of the sun, represented as a tuple of RGB values.
circles (list[Any]): A list of tuples, each representing a circle that contributes
to the flare effect. Each tuple contains:
- alpha (float): The transparency of the circle (0.0 to 1.0).
- center (tuple[int, int]): (x, y) coordinates of the circle center.
- radius (int): The radius of the circle.
- color (tuple[int, int, int]): RGB color of the circle.
Returns:
np.ndarray: The output image with the sun flare effect added.
Note:
- This function uses a simple alpha blending technique to overlay flare elements.
- The main sun is created as a gradient circle, fading from the center outwards.
- Additional flare circles are added along an imaginary line from the sun's position.
- This method is computationally efficient but may produce less realistic results
compared to more advanced techniques.
The flare effect is created through the following steps:
1. Create an overlay image and output image as copies of the input.
2. Add smaller flare circles to the overlay.
3. Blend the overlay with the output image using alpha compositing.
4. Add the main sun circle with a radial gradient.
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
>>> flare_center = (50, 50)
>>> src_radius = 20
>>> src_color = (255, 255, 200)
>>> circles = [
... (0.1, (60, 60), 5, (255, 200, 200)),
... (0.2, (70, 70), 3, (200, 255, 200))
... ]
>>> flared_image = A.functional.add_sun_flare_overlay(
... image, flare_center, src_radius, src_color, circles
... )
References:
- Alpha compositing: https://en.wikipedia.org/wiki/Alpha_compositing
- Lens flare: https://en.wikipedia.org/wiki/Lens_flare
"""
overlay = img.copy()
output = img.copy()
weighted_brightness = 0.0
total_radius_length = 0.0
for alpha, (x, y), rad3, circle_color in circles:
weighted_brightness += alpha * rad3
total_radius_length += rad3
cv2.circle(overlay, (x, y), rad3, circle_color, -1)
output = add_weighted(overlay, alpha, output, 1 - alpha)
point = [int(x) for x in flare_center]
overlay = output.copy()
num_times = src_radius // 10
# max_alpha is weighted_brightness / total_radius_length scaled by 5:
# the larger the area-weighted alpha, the brighter the bright spot.
# For alphas in the range [0.05, 0.2], max_alpha stays below 1.
max_alpha = weighted_brightness / total_radius_length * 5
alpha = np.linspace(0.0, min(max_alpha, 1.0), num=num_times)
rad = np.linspace(1, src_radius, num=num_times)
for i in range(num_times):
cv2.circle(overlay, point, int(rad[i]), src_color, -1)
alp = alpha[num_times - i - 1] ** 3  # cube the alpha for a steeper brightness falloff
output = add_weighted(overlay, alp, output, 1 - alp)
return output
def add_sun_flare_physics_based (img, flare_center, src_radius, src_color, circles)
[view source on GitHub]¶
Add a more realistic sun flare effect to the image.
This function creates a complex sun flare effect by simulating various optical phenomena that occur in real camera lenses when capturing bright light sources. The result is a more realistic and physically plausible lens flare effect.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. |
flare_center | tuple[int, int] | (x, y) coordinates of the sun's center in pixels. |
src_radius | int | Radius of the main sun circle in pixels. |
src_color | tuple[int, int, int] | Color of the sun in RGB format. |
circles | list[Any] | List of tuples, each representing a flare circle with parameters: (alpha, center, size, color) - alpha (float): Transparency of the circle (0.0 to 1.0). - center (tuple[int, int]): (x, y) coordinates of the circle center. - size (float): Size factor for the circle radius. - color (tuple[int, int, int]): RGB color of the circle. |
Returns:
Type | Description |
---|---|
np.ndarray | Image with added sun flare effect. |
Note
This function implements several techniques to create a more realistic flare:
1. Separate flare layer: Allows for complex manipulations of the flare effect.
2. Lens diffraction spikes: Simulates light diffraction in camera aperture.
3. Radial gradient mask: Creates natural fading of the flare from the center.
4. Gaussian blur: Softens the flare for a more natural glow effect.
5. Chromatic aberration: Simulates color fringing often seen in real lens flares.
6. Screen blending: Provides a more realistic blending of the flare with the image.
The flare effect is created through the following steps:
1. Create a separate flare layer.
2. Add the main sun circle and diffraction spikes to the flare layer.
3. Add additional flare circles based on the input parameters.
4. Apply Gaussian blur to soften the flare.
5. Create and apply a radial gradient mask for natural fading.
6. Simulate chromatic aberration by applying different blurs to color channels.
7. Blend the flare with the original image using screen blending mode.
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [1000, 1000, 3], dtype=np.uint8)
>>> flare_center = (500, 500)
>>> src_radius = 50
>>> src_color = (255, 255, 200)
>>> circles = [
... (0.1, (550, 550), 10, (255, 200, 200)),
... (0.2, (600, 600), 5, (200, 255, 200))
... ]
>>> flared_image = A.functional.add_sun_flare_physics_based(
... image, flare_center, src_radius, src_color, circles
... )
References
- Lens flare: https://en.wikipedia.org/wiki/Lens_flare
- Diffraction: https://en.wikipedia.org/wiki/Diffraction
- Chromatic aberration: https://en.wikipedia.org/wiki/Chromatic_aberration
- Screen blending: https://en.wikipedia.org/wiki/Blend_modes#Screen
Source code in albumentations/augmentations/functional.py
@uint8_io
@clipped
def add_sun_flare_physics_based(
img: np.ndarray,
flare_center: tuple[int, int],
src_radius: int,
src_color: tuple[int, int, int],
circles: list[Any],
) -> np.ndarray:
"""Add a more realistic sun flare effect to the image.
This function creates a complex sun flare effect by simulating various optical phenomena
that occur in real camera lenses when capturing bright light sources. The result is a
more realistic and physically plausible lens flare effect.
Args:
img (np.ndarray): Input image.
flare_center (tuple[int, int]): (x, y) coordinates of the sun's center in pixels.
src_radius (int): Radius of the main sun circle in pixels.
src_color (tuple[int, int, int]): Color of the sun in RGB format.
circles (list[Any]): List of tuples, each representing a flare circle with parameters:
(alpha, center, size, color)
- alpha (float): Transparency of the circle (0.0 to 1.0).
- center (tuple[int, int]): (x, y) coordinates of the circle center.
- size (float): Size factor for the circle radius.
- color (tuple[int, int, int]): RGB color of the circle.
Returns:
np.ndarray: Image with added sun flare effect.
Note:
This function implements several techniques to create a more realistic flare:
1. Separate flare layer: Allows for complex manipulations of the flare effect.
2. Lens diffraction spikes: Simulates light diffraction in camera aperture.
3. Radial gradient mask: Creates natural fading of the flare from the center.
4. Gaussian blur: Softens the flare for a more natural glow effect.
5. Chromatic aberration: Simulates color fringing often seen in real lens flares.
6. Screen blending: Provides a more realistic blending of the flare with the image.
The flare effect is created through the following steps:
1. Create a separate flare layer.
2. Add the main sun circle and diffraction spikes to the flare layer.
3. Add additional flare circles based on the input parameters.
4. Apply Gaussian blur to soften the flare.
5. Create and apply a radial gradient mask for natural fading.
6. Simulate chromatic aberration by applying different blurs to color channels.
7. Blend the flare with the original image using screen blending mode.
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, [1000, 1000, 3], dtype=np.uint8)
>>> flare_center = (500, 500)
>>> src_radius = 50
>>> src_color = (255, 255, 200)
>>> circles = [
... (0.1, (550, 550), 10, (255, 200, 200)),
... (0.2, (600, 600), 5, (200, 255, 200))
... ]
>>> flared_image = A.functional.add_sun_flare_physics_based(
... image, flare_center, src_radius, src_color, circles
... )
References:
- Lens flare: https://en.wikipedia.org/wiki/Lens_flare
- Diffraction: https://en.wikipedia.org/wiki/Diffraction
- Chromatic aberration: https://en.wikipedia.org/wiki/Chromatic_aberration
- Screen blending: https://en.wikipedia.org/wiki/Blend_modes#Screen
"""
output = img.copy()
height, width = img.shape[:2]
# Create a separate flare layer
flare_layer = np.zeros_like(img, dtype=np.float32)
# Add the main sun
cv2.circle(flare_layer, flare_center, src_radius, src_color, -1)
# Add lens diffraction spikes
for angle in [0, 45, 90, 135]:
end_point = (
int(flare_center[0] + np.cos(np.radians(angle)) * max(width, height)),
int(flare_center[1] + np.sin(np.radians(angle)) * max(width, height)),
)
cv2.line(flare_layer, flare_center, end_point, src_color, 2)
# Add flare circles
for _, center, size, color in circles:
cv2.circle(flare_layer, center, int(size**0.33), color, -1)
# Apply gaussian blur to soften the flare
flare_layer = cv2.GaussianBlur(flare_layer, (0, 0), sigmaX=15, sigmaY=15)
# Create a radial gradient mask
y, x = np.ogrid[:height, :width]
mask = np.sqrt((x - flare_center[0]) ** 2 + (y - flare_center[1]) ** 2)
mask = 1 - np.clip(mask / (max(width, height) * 0.7), 0, 1)
mask = np.dstack([mask] * 3)
# Apply the mask to the flare layer
flare_layer *= mask
# Add chromatic aberration
channels = list(cv2.split(flare_layer))
channels[0] = cv2.GaussianBlur(
channels[0],
(0, 0),
sigmaX=3,
sigmaY=3,
) # Blue channel
channels[2] = cv2.GaussianBlur(
channels[2],
(0, 0),
sigmaX=5,
sigmaY=5,
) # Red channel
flare_layer = cv2.merge(channels)
# Blend the flare with the original image using screen blending
return 255 - ((255 - output) * (255 - flare_layer) / 255)
def apply_corner_illumination (img, intensity, corner)
[view source on GitHub]¶
Apply corner-based illumination effect.
Source code in albumentations/augmentations/functional.py
@clipped
def apply_corner_illumination(
img: np.ndarray,
intensity: float,
corner: Literal[0, 1, 2, 3],
) -> np.ndarray:
"""Apply corner-based illumination effect."""
if intensity == 0:
return img.copy()
height, width = img.shape[:2]
# Pre-compute diagonal length once
diagonal_length = math.sqrt(height * height + width * width)
# Create inverted distance map mask directly
# Use uint8 for distanceTransform regardless of input dtype
mask = np.full((height, width), 255, dtype=np.uint8)
# Use array indexing instead of conditionals
corners = [(0, 0), (0, width - 1), (height - 1, width - 1), (height - 1, 0)]
mask[corners[corner]] = 0
# Calculate distance transform
pattern = cv2.distanceTransform(
mask,
distanceType=cv2.DIST_L2,
maskSize=cv2.DIST_MASK_PRECISE,
dstType=cv2.CV_32F, # Specify float output directly
)
# Combine operations to reduce array copies
cv2.multiply(pattern, -intensity / diagonal_length, dst=pattern)
cv2.add(pattern, 1, dst=pattern)
if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
pattern = cv2.merge([pattern] * img.shape[2])
return multiply_by_array(img, pattern)
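A minimal usage sketch (per the `corners` list in the source, corner index 0 is the top-left, proceeding clockwise; the scale factor is 1 at the chosen corner and falls off with distance, so that corner appears lit relative to the rest of the image):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> lit = F.apply_corner_illumination(image, intensity=0.3, corner=0)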
def apply_gaussian_illumination (img, intensity, center, sigma)
[view source on GitHub]¶
Apply gaussian illumination effect.
Source code in albumentations/augmentations/functional.py
@clipped
def apply_gaussian_illumination(
img: np.ndarray,
intensity: float,
center: tuple[float, float],
sigma: float,
) -> np.ndarray:
"""Apply gaussian illumination effect."""
if intensity == 0:
return img.copy()
height, width = img.shape[:2]
# Pre-compute constants
center_x = width * center[0]
center_y = height * center[1]
sigma2 = 2 * (max(height, width) * sigma) ** 2 # Pre-compute denominator
# Create coordinate grid and calculate distances in-place
y, x = np.ogrid[:height, :width]
x = x.astype(np.float32)
y = y.astype(np.float32)
x -= center_x
y -= center_y
# Calculate squared distances in-place
cv2.multiply(x, x, dst=x)
cv2.multiply(y, y, dst=y)
x = x + y
# Calculate gaussian directly into x array
cv2.multiply(x, -1 / sigma2, dst=x)
cv2.exp(x, dst=x)
# Scale by intensity
cv2.multiply(x, intensity, dst=x)
cv2.add(x, 1, dst=x)
if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
x = cv2.merge([x] * img.shape[2])
return multiply_by_array(img, x)
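A minimal usage sketch (note that `center` is given as fractions of width and height, per `center_x = width * center[0]` in the source):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> # Bright spot centered in the image, fading with a gaussian falloff
>>> spot = F.apply_gaussian_illumination(image, intensity=0.4, center=(0.5, 0.5), sigma=0.25)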
def apply_linear_illumination (img, intensity, angle)
[view source on GitHub]¶
Apply directional illumination effect to an image using a linear gradient.
The function creates a directional gradient and uses it to modulate image brightness. The gradient direction is controlled by the angle parameter, and the strength of the effect is controlled by the intensity parameter.
The illumination is applied by multiplying the image with a scale factor that varies linearly across the image. The scale factor ranges from (1-|intensity|) to (1+|intensity|).
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image in range [0, 1]. Can be single or multi-channel. |
intensity | float | Strength and direction of the illumination effect, range [-1, 1]. - Positive values brighten in gradient direction - Negative values darken in gradient direction - Magnitude determines strength of the effect |
angle | float | Direction of the gradient in degrees. - 0: left to right - 90: bottom to top - 180: right to left - 270: top to bottom |
Returns:
Type | Description |
---|---|
np.ndarray | Image with applied illumination effect, same shape and range as input. |
Implementation details:
1. Creates a directional gradient in range [0, 1]
2. For negative intensity, inverts the gradient (1 - gradient)
3. For multi-channel images, repeats gradient across channels
4. Computes scale factor in-place: scale = 1 - |intensity| + 2 * |intensity| * gradient, which maps gradient [0, 1] to scale [(1-|i|), (1+|i|)]
5. Multiplies image by scale factor
Note
Uses in-place operations where possible for memory efficiency. The @float32_io decorator ensures float32 precision. The @clipped decorator ensures output values stay in valid range.
Source code in albumentations/augmentations/functional.py
@float32_io
def apply_linear_illumination(img: np.ndarray, intensity: float, angle: float) -> np.ndarray:
"""Apply directional illumination effect to an image using a linear gradient.
The function creates a directional gradient and uses it to modulate image brightness.
The gradient direction is controlled by the angle parameter, and the strength of the
effect is controlled by the intensity parameter.
The illumination is applied by multiplying the image with a scale factor that varies
linearly across the image. The scale factor ranges from (1-|intensity|) to (1+|intensity|).
Args:
img: Input image in range [0, 1]. Can be single or multi-channel.
intensity: Strength and direction of the illumination effect, range [-1, 1].
- Positive values brighten in gradient direction
- Negative values darken in gradient direction
- Magnitude determines strength of the effect
angle: Direction of the gradient in degrees.
- 0: left to right
- 90: bottom to top
- 180: right to left
- 270: top to bottom
Returns:
Image with applied illumination effect, same shape and range as input.
Implementation details:
1. Creates a directional gradient in range [0, 1]
2. For negative intensity, inverts the gradient (1 - gradient)
3. For multi-channel images, repeats gradient across channels
4. Computes scale factor in-place:
scale = 1 - |intensity| + 2 * |intensity| * gradient
This maps gradient [0, 1] to scale [(1-|i|), (1+|i|)]
5. Multiplies image by scale factor
Note:
Uses in-place operations where possible for memory efficiency.
The @float32_io decorator ensures float32 precision.
The @clipped decorator ensures output values stay in valid range.
"""
height, width = img.shape[:2]
abs_intensity = abs(intensity)
# Create gradient and handle negative intensity in one step
gradient = create_directional_gradient(height, width, angle)
if intensity < 0:
cv2.subtract(1, gradient, dst=gradient)
cv2.multiply(gradient, 2 * abs_intensity, dst=gradient)
cv2.add(gradient, 1 - abs_intensity, dst=gradient)
# Add channel dimension if needed
if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
gradient = gradient[..., np.newaxis]
return multiply_by_array(img, gradient)
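A small worked check of the scale mapping described above (assuming the module import path): with intensity=0.5 and angle=0 the scale runs linearly from 0.5 on the left edge to 1.5 on the right, so a constant 0.5 image maps to 0.25 and 0.75 at the edges.
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> image = np.full((4, 4, 3), 0.5, dtype=np.float32)
>>> out = F.apply_linear_illumination(image, intensity=0.5, angle=0)
>>> float(out[0, 0, 0]), float(out[0, -1, 0])  # 0.5 * 0.5 and 0.5 * 1.5
(0.25, 0.75)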
def apply_plasma_brightness_contrast (img, brightness_factor, contrast_factor, plasma_pattern)
[view source on GitHub]¶
Apply plasma-based brightness and contrast adjustments.
Source code in albumentations/augmentations/functional.py
@clipped
@float32_io
def apply_plasma_brightness_contrast(
img: np.ndarray,
brightness_factor: float,
contrast_factor: float,
plasma_pattern: np.ndarray,
) -> np.ndarray:
"""Apply plasma-based brightness and contrast adjustments."""
# Early return if no adjustments needed
if brightness_factor == 0 and contrast_factor == 0:
return img
img = img.copy()
# Expand plasma pattern once if needed
if img.ndim > MONO_CHANNEL_DIMENSIONS:
plasma_pattern = np.tile(plasma_pattern[..., np.newaxis], (1, 1, img.shape[-1]))
# Apply brightness adjustment
if brightness_factor != 0:
brightness_adjustment = multiply(plasma_pattern, brightness_factor, inplace=False)
img = add(img, brightness_adjustment, inplace=True)
# Apply contrast adjustment
if contrast_factor != 0:
mean = img.mean()
contrast_weights = multiply(plasma_pattern, contrast_factor, inplace=False) + 1
img = multiply(img, contrast_weights, inplace=True)
mean_factor = mean * (1.0 - contrast_weights)
return add(img, mean_factor, inplace=True)
return img
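A short usage sketch (hedged: a uniform random field in [0, 1] stands in for the output of `generate_plasma_pattern` below, which is the intended source of the pattern):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> rng = np.random.default_rng(0)
>>> plasma = rng.random((100, 100)).astype(np.float32)  # stand-in plasma pattern
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> out = F.apply_plasma_brightness_contrast(
...     image, brightness_factor=0.3, contrast_factor=0.2, plasma_pattern=plasma)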
def apply_plasma_shadow (img, intensity, plasma_pattern)
[view source on GitHub]¶
Apply plasma-based shadow effect by darkening.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image |
intensity | float | Shadow intensity in [0, 1] |
plasma_pattern | np.ndarray | Generated plasma pattern of shape (H, W) |
Returns:
Type | Description |
---|---|
np.ndarray | Image with applied shadow effect |
Source code in albumentations/augmentations/functional.py
@clipped
def apply_plasma_shadow(
img: np.ndarray,
intensity: float,
plasma_pattern: np.ndarray,
) -> np.ndarray:
"""Apply plasma-based shadow effect by darkening.
Args:
img: Input image
intensity: Shadow intensity in [0, 1]
plasma_pattern: Generated plasma pattern of shape (H, W)
Returns:
Image with applied shadow effect
"""
# Scale plasma pattern by intensity first (scalar operation)
scaled_pattern = plasma_pattern * intensity
# Expand dimensions only once if needed
if img.ndim > MONO_CHANNEL_DIMENSIONS:
scaled_pattern = scaled_pattern[..., np.newaxis]
# Single multiply operation
return img * (1 - scaled_pattern)
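Usage follows the same pattern as the brightness/contrast example above (again with a stand-in pattern in [0, 1]):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> plasma = np.random.default_rng(0).random((100, 100)).astype(np.float32)
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> shadowed = F.apply_plasma_shadow(image, intensity=0.5, plasma_pattern=plasma)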
def apply_salt_and_pepper (img, salt_mask, pepper_mask)
[view source on GitHub]¶
Apply salt and pepper noise to image using pre-computed masks.
Source code in albumentations/augmentations/functional.py
def apply_salt_and_pepper(
img: np.ndarray,
salt_mask: np.ndarray,
pepper_mask: np.ndarray,
) -> np.ndarray:
"""Apply salt and pepper noise to image using pre-computed masks."""
# Add channel dimension to masks if image is 3D
if img.ndim == 3:
salt_mask = salt_mask[..., None]
pepper_mask = pepper_mask[..., None]
max_value = MAX_VALUES_BY_DTYPE[img.dtype]
return np.where(salt_mask, max_value, np.where(pepper_mask, 0, img))
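A short usage sketch (the masks are normally pre-computed by the calling transform; drawn at random here, keeping salt and pepper disjoint):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> rng = np.random.default_rng(0)
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> salt = rng.random((100, 100)) < 0.02
>>> pepper = (rng.random((100, 100)) < 0.02) & ~salt
>>> noisy = F.apply_salt_and_pepper(image, salt_mask=salt, pepper_mask=pepper)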
def auto_contrast (img, cutoff, ignore, method)
[view source on GitHub]¶
Apply auto contrast to the image.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image in uint8 or float32 format. |
cutoff | float | Percentage of pixels to cut off from the histogram edges. Range: 0-100. Default: 0 (no cutoff) |
ignore | int | None | Pixel value to ignore in auto contrast calculation. Useful for handling alpha channels or other special values. |
method | Literal['cdf', 'pil'] | Method to use for contrast enhancement: - "cdf": Uses cumulative distribution function (original albumentations method) - "pil": Uses linear scaling like PIL.ImageOps.autocontrast |
Returns:
Type | Description |
---|---|
np.ndarray | Contrast-enhanced image in the same dtype as input. |
Source code in albumentations/augmentations/functional.py
@uint8_io
def auto_contrast(
img: np.ndarray,
cutoff: float,
ignore: int | None,
method: Literal["cdf", "pil"],
) -> np.ndarray:
"""Apply auto contrast to the image.
Args:
img: Input image in uint8 or float32 format.
cutoff: Percentage of pixels to cut off from the histogram edges.
Range: 0-100. Default: 0 (no cutoff)
ignore: Pixel value to ignore in auto contrast calculation.
Useful for handling alpha channels or other special values.
method: Method to use for contrast enhancement:
- "cdf": Uses cumulative distribution function (original albumentations method)
- "pil": Uses linear scaling like PIL.ImageOps.autocontrast
Returns:
Contrast-enhanced image in the same dtype as input.
"""
result = img.copy()
num_channels = get_num_channels(img)
max_value = MAX_VALUES_BY_DTYPE[img.dtype]
# Pre-compute histograms using cv2.calcHist - much faster than np.histogram
if img.ndim > MONO_CHANNEL_DIMENSIONS:
channels = cv2.split(img)
hists: list[np.ndarray] = []
for i, channel in enumerate(channels):
if ignore is not None and i == ignore:
hists.append(None)
continue
mask = None if ignore is None else (channel != ignore)
hist = cv2.calcHist([channel], [0], mask, [256], [0, max_value])
hists.append(hist.ravel())
for i in range(num_channels):
if ignore is not None and i == ignore:
continue
if img.ndim > MONO_CHANNEL_DIMENSIONS:
hist = hists[i]
channel = channels[i]
else:
mask = None if ignore is None else (img != ignore)
hist = cv2.calcHist([img], [0], mask, [256], [0, max_value]).ravel()
channel = img
lo, hi = get_histogram_bounds(hist, cutoff)
if hi <= lo:
continue
lut = create_contrast_lut(hist, lo, hi, max_value, method)
if ignore is not None:
lut[ignore] = ignore
if img.ndim > MONO_CHANNEL_DIMENSIONS:
result[..., i] = sz_lut(channel, lut)
else:
result = sz_lut(channel, lut)
return result
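A short usage sketch (a deliberately low-contrast input makes the stretch visible):
>>> import numpy as np
>>> import albumentations.augmentations.functional as F
>>> image = np.random.randint(80, 160, (100, 100, 3), dtype=np.uint8)  # low-contrast input
>>> stretched = F.auto_contrast(image, cutoff=2.0, ignore=None, method="pil")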
def clahe (img, clip_limit, tile_grid_size)
[view source on GitHub]¶
Apply Contrast Limited Adaptive Histogram Equalization (CLAHE) to the input image.
This function enhances the contrast of the input image using CLAHE. For color images, it converts the image to the LAB color space, applies CLAHE to the L channel, and then converts the image back to RGB.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. Can be grayscale (2D array) or RGB (3D array). |
clip_limit | float | Threshold for contrast limiting. Higher values give more contrast. |
tile_grid_size | tuple[int, int] | Size of grid for histogram equalization. Width and height of the grid. |
Returns:
Type | Description |
---|---|
np.ndarray | Image with CLAHE applied. The output has the same dtype as the input. |
Note
- If the input image is float32, it's temporarily converted to uint8 for processing and then converted back to float32.
- For color images, CLAHE is applied only to the luminance channel in the LAB color space.
Exceptions:
Type | Description |
---|---|
ValueError | If the input image is not 2D or 3D. |
Examples:
>>> import numpy as np
>>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> result = clahe(img, clip_limit=2.0, tile_grid_size=(8, 8))
>>> assert result.shape == img.shape
>>> assert result.dtype == img.dtype
Source code in albumentations/augmentations/functional.py
@uint8_io
@preserve_channel_dim
def clahe(
img: np.ndarray,
clip_limit: float,
tile_grid_size: tuple[int, int],
) -> np.ndarray:
"""Apply Contrast Limited Adaptive Histogram Equalization (CLAHE) to the input image.
This function enhances the contrast of the input image using CLAHE. For color images,
it converts the image to the LAB color space, applies CLAHE to the L channel, and then
converts the image back to RGB.
Args:
img (np.ndarray): Input image. Can be grayscale (2D array) or RGB (3D array).
clip_limit (float): Threshold for contrast limiting. Higher values give more contrast.
tile_grid_size (tuple[int, int]): Size of grid for histogram equalization.
Width and height of the grid.
Returns:
np.ndarray: Image with CLAHE applied. The output has the same dtype as the input.
Note:
- If the input image is float32, it's temporarily converted to uint8 for processing
and then converted back to float32.
- For color images, CLAHE is applied only to the luminance channel in the LAB color space.
Raises:
ValueError: If the input image is not 2D or 3D.
Example:
>>> import numpy as np
>>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> result = clahe(img, clip_limit=2.0, tile_grid_size=(8, 8))
>>> assert result.shape == img.shape
>>> assert result.dtype == img.dtype
"""
img = img.copy()
clahe_mat = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
if is_grayscale_image(img):
return clahe_mat.apply(img)
img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
img[:, :, 0] = clahe_mat.apply(img[:, :, 0])
return cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
def create_contrast_lut (hist, min_intensity, max_intensity, max_value, method)
[view source on GitHub]¶
Create lookup table for contrast adjustment.
Source code in albumentations/augmentations/functional.py
def create_contrast_lut(
hist: np.ndarray,
min_intensity: int,
max_intensity: int,
max_value: int,
method: Literal["cdf", "pil"],
) -> np.ndarray:
"""Create lookup table for contrast adjustment."""
# Handle single intensity case
if min_intensity >= max_intensity:
return np.zeros(256, dtype=np.uint8)
if method == "cdf":
hist_range = hist[min_intensity : max_intensity + 1]
cdf = hist_range.cumsum()
if cdf[-1] == 0: # No valid pixels
return np.arange(256, dtype=np.uint8)
# Normalize CDF to full range
cdf = (cdf - cdf[0]) * max_value / (cdf[-1] - cdf[0])
# Create lookup table
lut = np.zeros(256, dtype=np.uint8)
lut[min_intensity : max_intensity + 1] = np.clip(np.round(cdf), 0, max_value).astype(np.uint8)
lut[max_intensity + 1 :] = max_value
return lut
# "pil" method
scale = max_value / (max_intensity - min_intensity)
indices = np.arange(256, dtype=float)
# Use np.round so the midpoint maps exactly (e.g. the range [0, 2] yields LUT values [0, 128, 255])
lut = np.clip(np.round((indices - min_intensity) * scale), 0, max_value).astype(np.uint8)
lut[:min_intensity] = 0
lut[max_intensity + 1 :] = max_value
return lut
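A quick check of the "pil" branch, mirroring the behaviour noted in the comment above (the histogram argument is unused on this path, so a dummy array suffices):
>>> import numpy as np
>>> from albumentations.augmentations.functional import create_contrast_lut
>>> lut = create_contrast_lut(np.ones(256), min_intensity=0, max_intensity=2, max_value=255, method="pil")
>>> lut[:3].tolist()
[0, 128, 255]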
def create_directional_gradient (height, width, angle)
[view source on GitHub]¶
Create a directional gradient in [0, 1] range.
Optimized implementation using broadcasting and fast paths for common angles:
- 0°, 180°: horizontal gradients using a single linspace
- 90°, 270°: vertical gradients using a single linspace
- 45°, 135°, 225°, 315°: diagonal gradients using equal combinations of horizontal and vertical
- Other angles: computed using trigonometric functions
Source code in albumentations/augmentations/functional.py
def create_directional_gradient(height: int, width: int, angle: float) -> np.ndarray:
"""Create a directional gradient in [0, 1] range.
Optimized implementation using broadcasting and fast paths for common angles:
- 0°, 180°: horizontal gradients using single linspace
- 90°, 270°: vertical gradients using single linspace
- 45°, 135°, 225°, 315°: diagonal gradients using equal combinations of horizontal and vertical
- Other angles: computed using trigonometric functions
"""
# Fast path for horizontal gradients
if angle == 0:
return np.linspace(0, 1, width, dtype=np.float32)[None, :] * np.ones((height, 1), dtype=np.float32)
if angle == 180:
return np.linspace(1, 0, width, dtype=np.float32)[None, :] * np.ones((height, 1), dtype=np.float32)
# Fast path for vertical gradients
if angle == 90:
return np.linspace(0, 1, height, dtype=np.float32)[:, None] * np.ones((1, width), dtype=np.float32)
if angle == 270:
return np.linspace(1, 0, height, dtype=np.float32)[:, None] * np.ones((1, width), dtype=np.float32)
# Fast path for diagonal gradients using broadcasting
if angle in (45, 135, 225, 315):
x = np.linspace(0, 1, width, dtype=np.float32)[None, :] # Horizontal
y = np.linspace(0, 1, height, dtype=np.float32)[:, None] # Vertical
if angle == 45: # Bottom-left to top-right
return cv2.normalize(x + y, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
if angle == 135: # Bottom-right to top-left
return cv2.normalize((1 - x) + y, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
if angle == 225: # Top-right to bottom-left
return cv2.normalize((1 - x) + (1 - y), None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
# angle == 315: # Top-left to bottom-right
return cv2.normalize(x + (1 - y), None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
# General case for arbitrary angles using broadcasting
y = np.linspace(0, 1, height, dtype=np.float32)[:, None] # Column vector
x = np.linspace(0, 1, width, dtype=np.float32)[None, :] # Row vector
angle_rad = np.deg2rad(angle)
cos_a = math.cos(angle_rad)
sin_a = math.sin(angle_rad)
cv2.multiply(x, cos_a, dst=x)
cv2.multiply(y, sin_a, dst=y)
return x + y
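A quick check of the horizontal fast path (each row runs linearly from 0 on the left to 1 on the right):
>>> from albumentations.augmentations.functional import create_directional_gradient
>>> g = create_directional_gradient(2, 3, angle=0)
>>> g[0].tolist()
[0.0, 0.5, 1.0]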
def equalize (img, mask=None, mode='cv', by_channels=True)
[view source on GitHub]¶
Apply histogram equalization to the input image.
This function enhances the contrast of the input image by equalizing its histogram. It supports both grayscale and color images, and can operate on individual channels or on the luminance channel of the image.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. Can be grayscale (2D array) or RGB (3D array). |
mask | np.ndarray | None | Optional mask to apply the equalization selectively. If provided, must have the same shape as the input image. Default: None. |
mode | ImageMode | The backend to use for equalization. Can be either "cv" for OpenCV or "pil" for Pillow-style equalization. Default: "cv". |
by_channels | bool | If True, applies equalization to each channel independently. If False, converts the image to YCrCb color space and equalizes only the luminance channel. Only applicable to color images. Default: True. |
Returns:
Type | Description |
---|---|
np.ndarray | Equalized image. The output has the same dtype as the input. |
Exceptions:
Type | Description |
---|---|
ValueError | If the input image or mask have invalid shapes or types. |
Note
- If the input image is not uint8, it will be temporarily converted to uint8 for processing and then converted back to its original dtype.
- For color images, when by_channels=False, the image is converted to YCrCb color space, equalized on the Y channel, and then converted back to RGB.
- The function preserves the original number of channels in the image.
Examples:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> equalized = A.equalize(image, mode="cv", by_channels=True)
>>> assert equalized.shape == image.shape
>>> assert equalized.dtype == image.dtype
Source code in albumentations/augmentations/functional.py
@uint8_io
@preserve_channel_dim
def equalize(
img: np.ndarray,
mask: np.ndarray | None = None,
mode: Literal["cv", "pil"] = "cv",
by_channels: bool = True,
) -> np.ndarray:
"""Apply histogram equalization to the input image.
This function enhances the contrast of the input image by equalizing its histogram.
It supports both grayscale and color images, and can operate on individual channels
or on the luminance channel of the image.
Args:
img (np.ndarray): Input image. Can be grayscale (2D array) or RGB (3D array).
mask (np.ndarray | None): Optional mask to apply the equalization selectively.
If provided, must have the same shape as the input image. Default: None.
mode (ImageMode): The backend to use for equalization. Can be either "cv" for
OpenCV or "pil" for Pillow-style equalization. Default: "cv".
by_channels (bool): If True, applies equalization to each channel independently.
If False, converts the image to YCrCb color space and equalizes only the
luminance channel. Only applicable to color images. Default: True.
Returns:
np.ndarray: Equalized image. The output has the same dtype as the input.
Raises:
ValueError: If the input image or mask have invalid shapes or types.
Note:
- If the input image is not uint8, it will be temporarily converted to uint8
for processing and then converted back to its original dtype.
- For color images, when by_channels=False, the image is converted to YCrCb
color space, equalized on the Y channel, and then converted back to RGB.
- The function preserves the original number of channels in the image.
Example:
>>> import numpy as np
>>> import albumentations as A
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> equalized = A.equalize(image, mode="cv", by_channels=True)
>>> assert equalized.shape == image.shape
>>> assert equalized.dtype == image.dtype
"""
_check_preconditions(img, mask, by_channels)
function = _equalize_pil if mode == "pil" else _equalize_cv
if is_grayscale_image(img):
return function(img, _handle_mask(mask))
if not by_channels:
result_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
result_img[..., 0] = function(result_img[..., 0], _handle_mask(mask))
return cv2.cvtColor(result_img, cv2.COLOR_YCrCb2RGB)
result_img = np.empty_like(img)
for i in range(NUM_RGB_CHANNELS):
_mask = _handle_mask(mask, i)
result_img[..., i] = function(img[..., i], _mask)
return result_img
def fancy_pca (img, alpha_vector)
[view source on GitHub]¶
Perform 'Fancy PCA' augmentation on an image with any number of channels.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image |
alpha_vector | np.ndarray | Vector of scale factors for each principal component. Should have the same length as the number of channels in the image. |
Returns:
Type | Description |
---|---|
np.ndarray | Augmented image of the same shape, type, and range as the input. |
Image types: uint8, float32
Number of channels: Any
Note
- This function generalizes the Fancy PCA augmentation to work with any number of channels.
- It preserves the original range of the image ([0, 255] for uint8, [0, 1] for float32).
- For single-channel images, the augmentation is applied as a simple scaling of pixel intensity variation.
- For multi-channel images, PCA is performed on the entire image, treating each pixel as a point in N-dimensional space (where N is the number of channels).
- The augmentation preserves the correlation between channels while adding controlled noise.
- Computation time may increase significantly for images with a large number of channels.
Reference
Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). ImageNet classification with deep convolutional neural networks. In Advances in neural information processing systems (pp. 1097-1105).
Source code in albumentations/augmentations/functional.py
@float32_io
@clipped
@preserve_channel_dim
def fancy_pca(img: np.ndarray, alpha_vector: np.ndarray) -> np.ndarray:
"""Perform 'Fancy PCA' augmentation on an image with any number of channels.
Args:
img (np.ndarray): Input image
alpha_vector (np.ndarray): Vector of scale factors for each principal component.
Should have the same length as the number of channels in the image.
Returns:
np.ndarray: Augmented image of the same shape, type, and range as the input.
Image types:
uint8, float32
Number of channels:
Any
Note:
- This function generalizes the Fancy PCA augmentation to work with any number of channels.
- It preserves the original range of the image ([0, 255] for uint8, [0, 1] for float32).
- For single-channel images, the augmentation is applied as a simple scaling of pixel intensity variation.
- For multi-channel images, PCA is performed on the entire image, treating each pixel
as a point in N-dimensional space (where N is the number of channels).
- The augmentation preserves the correlation between channels while adding controlled noise.
- Computation time may increase significantly for images with a large number of channels.
Reference:
Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012).
ImageNet classification with deep convolutional neural networks.
In Advances in neural information processing systems (pp. 1097-1105).
"""
orig_shape = img.shape
num_channels = get_num_channels(img)
# Reshape image to 2D array of pixels
img_reshaped = img.reshape(-1, num_channels)
# Center the pixel values
img_mean = np.mean(img_reshaped, axis=0)
img_centered = img_reshaped - img_mean
if num_channels == 1:
# For grayscale images, apply a simple scaling
std_dev = np.std(img_centered)
noise = alpha_vector[0] * std_dev * img_centered
else:
# Compute covariance matrix
img_cov = np.cov(img_centered, rowvar=False)
# Compute eigenvectors & eigenvalues of the covariance matrix
eig_vals, eig_vecs = np.linalg.eigh(img_cov)
# Sort eigenvectors by eigenvalues in descending order
sort_perm = eig_vals[::-1].argsort()
eig_vals = eig_vals[sort_perm]
eig_vecs = eig_vecs[:, sort_perm]
# Create noise vector
noise = np.dot(
np.dot(eig_vecs, np.diag(alpha_vector * eig_vals)),
img_centered.T,
).T
# Add noise to the image
img_pca = img_reshaped + noise
# Reshape back to original shape
img_pca = img_pca.reshape(orig_shape)
# Clip values to [0, 1] range
return np.clip(img_pca, 0, 1, out=img_pca)
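Example (a minimal sketch; assumes fancy_pca is imported from albumentations.augmentations.functional, and samples the per-component scale factors from a normal distribution, as in the referenced paper):
>>> import numpy as np
>>> from albumentations.augmentations.functional import fancy_pca
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
>>> alpha_vector = rng.normal(0, 0.1, size=3)  # one scale factor per channel/component
>>> augmented = fancy_pca(image, alpha_vector)
>>> assert augmented.shape == image.shape and augmented.dtype == image.dtype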
def generate_constant_noise (noise_type, shape, params, max_value, random_generator)
[view source on GitHub]¶
Generate one value per channel.
Source code in albumentations/augmentations/functional.py
def generate_constant_noise(
noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
shape: tuple[int, ...],
params: dict[str, Any],
max_value: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Generate one value per channel."""
num_channels = shape[-1] if len(shape) > MONO_CHANNEL_DIMENSIONS else 1
return sample_noise(
noise_type,
(num_channels,),
params,
max_value,
random_generator,
)
def generate_per_pixel_noise (noise_type, shape, params, max_value, random_generator)
[view source on GitHub]¶
Generate separate noise map for each channel.
Source code in albumentations/augmentations/functional.py
def generate_per_pixel_noise(
noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
shape: tuple[int, ...],
params: dict[str, Any],
max_value: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Generate separate noise map for each channel."""
return sample_noise(noise_type, shape, params, max_value, random_generator)
def generate_plasma_pattern (target_shape, roughness, random_generator)
[view source on GitHub]¶
Generate Plasma Fractal with consistent brightness.
Source code in albumentations/augmentations/functional.py
def generate_plasma_pattern(
target_shape: tuple[int, int],
roughness: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Generate Plasma Fractal with consistent brightness."""
def one_diamond_square_step(current_grid: np.ndarray, noise_scale: float) -> np.ndarray:
next_height = (current_grid.shape[0] - 1) * 2 + 1
next_width = (current_grid.shape[1] - 1) * 2 + 1
# Pre-allocate expanded grid
expanded_grid = np.zeros((next_height, next_width), dtype=np.float32)
# Generate all noise at once for both steps (already scaled by noise_scale)
all_noise = random_generator.uniform(-noise_scale, noise_scale, (next_height, next_width)).astype(np.float32)
# Copy existing points with noise
expanded_grid[::2, ::2] = current_grid + all_noise[::2, ::2]
# Diamond step - keep separate for natural look
diamond_interpolation = cv2.filter2D(expanded_grid, -1, DIAMOND_KERNEL, borderType=cv2.BORDER_CONSTANT)
diamond_mask = diamond_interpolation > 0
expanded_grid += (diamond_interpolation + all_noise) * diamond_mask
# Square step - keep separate for natural look
square_interpolation = cv2.filter2D(expanded_grid, -1, SQUARE_KERNEL, borderType=cv2.BORDER_CONSTANT)
square_mask = square_interpolation > 0
expanded_grid += (square_interpolation + all_noise) * square_mask
# Normalize after each step to prevent value drift
return cv2.normalize(expanded_grid, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
# Pre-compute noise scales
max_dimension = max(target_shape)
power_of_two_size = 2 ** np.ceil(np.log2(max_dimension - 1)) + 1
total_steps = int(np.log2(power_of_two_size - 1) - 1)
noise_scales = np.float32([roughness**i for i in range(total_steps)])
# Initialize with small random grid
plasma_grid = random_generator.uniform(-1, 1, (3, 3)).astype(np.float32)
# Recursively apply diamond-square steps
for noise_scale in noise_scales:
plasma_grid = one_diamond_square_step(plasma_grid, noise_scale)
return np.clip(
cv2.normalize(plasma_grid[: target_shape[0], : target_shape[1]], None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F),
0,
1,
)
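Example (a sketch; lower roughness damps the noise added at finer recursion levels, giving smoother clouds, while values near 1 keep more fine detail):
>>> import numpy as np
>>> from albumentations.augmentations.functional import generate_plasma_pattern
>>> rng = np.random.default_rng(0)
>>> pattern = generate_plasma_pattern((100, 150), roughness=0.7, random_generator=rng)
>>> assert pattern.shape == (100, 150)
>>> assert pattern.min() >= 0.0 and pattern.max() <= 1.0  # normalized to [0, 1]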
def generate_random_values (channels, dtype, random_generator)
[view source on GitHub]¶
Generate random values for dropped pixels.
Parameters:
Name | Type | Description |
---|---|---|
channels | int | Number of channels in the image |
dtype | np.dtype | Data type of the image |
random_generator | np.random.Generator | Random number generator |
Returns:
Type | Description |
---|---|
np.ndarray | Array of random values |
Source code in albumentations/augmentations/functional.py
def generate_random_values(
channels: int,
dtype: np.dtype,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Generate random values for dropped pixels.
Args:
channels: Number of channels in the image
dtype: Data type of the image
random_generator: Random number generator
Returns:
Array of random values
"""
if dtype == np.uint8:
return random_generator.integers(
0,
int(MAX_VALUES_BY_DTYPE[dtype]),
size=channels,
dtype=dtype,
)
if dtype == np.float32:
return random_generator.uniform(0, 1, size=channels).astype(dtype)
raise ValueError(f"Unsupported dtype: {dtype}")
def generate_shared_noise (noise_type, shape, params, max_value, random_generator)
[view source on GitHub]¶
Generate one noise map and broadcast to all channels.
Parameters:
Name | Type | Description |
---|---|---|
noise_type | Literal['uniform', 'gaussian', 'laplace', 'beta'] | Type of noise distribution to use |
shape | tuple[int, ...] | Shape of the input image (H, W) or (H, W, C) |
params | dict[str, Any] | Parameters for the noise distribution |
max_value | float | Maximum value for the noise distribution |
random_generator | np.random.Generator | NumPy random generator instance |
Returns:
Type | Description |
---|---|
np.ndarray | Noise array of shape (H, W) or (H, W, C) where the same noise pattern is shared across all channels |
Source code in albumentations/augmentations/functional.py
def generate_shared_noise(
noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
shape: tuple[int, ...],
params: dict[str, Any],
max_value: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Generate one noise map and broadcast to all channels.
Args:
noise_type: Type of noise distribution to use
shape: Shape of the input image (H, W) or (H, W, C)
params: Parameters for the noise distribution
max_value: Maximum value for the noise distribution
random_generator: NumPy random generator instance
Returns:
Noise array of shape (H, W) or (H, W, C) where the same noise
pattern is shared across all channels
"""
# Generate noise for (H, W)
height, width = shape[:2]
noise_map = sample_noise(
noise_type,
(height, width),
params,
max_value,
random_generator,
)
# If input is multichannel, broadcast noise to all channels
if len(shape) > MONO_CHANNEL_DIMENSIONS:
return np.broadcast_to(noise_map[..., None], shape)
return noise_map
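The three generators differ only in how much noise is shared: one value per channel, an independent map per channel, or one map broadcast across all channels. A sketch (the params dict follows the per-distribution conventions documented under sample_gaussian below):
>>> import numpy as np
>>> from albumentations.augmentations.functional import (
...     generate_constant_noise, generate_per_pixel_noise, generate_shared_noise,
... )
>>> rng = np.random.default_rng(0)
>>> params = {"mean_range": (0.0, 0.0), "std_range": (0.1, 0.1)}
>>> shape = (32, 32, 3)
>>> constant = generate_constant_noise("gaussian", shape, params, 255, rng)
>>> per_pixel = generate_per_pixel_noise("gaussian", shape, params, 255, rng)
>>> shared = generate_shared_noise("gaussian", shape, params, 255, rng)
>>> assert constant.shape == (3,)  # one value per channel
>>> assert per_pixel.shape == shape and shared.shape == shape
>>> assert np.allclose(shared[..., 0], shared[..., 1])  # same map in every channel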
def generate_snow_textures (img_shape, random_generator)
[view source on GitHub]¶
Generate snow texture and sparkle mask.
Parameters:
Name | Type | Description |
---|---|---|
img_shape | tuple[int, int] | Image shape. |
random_generator | np.random.Generator | Random generator to use. |
Returns:
Type | Description |
---|---|
tuple[np.ndarray, np.ndarray] | Tuple of (snow_texture, sparkle_mask) arrays. |
Source code in albumentations/augmentations/functional.py
def generate_snow_textures(
img_shape: tuple[int, int],
random_generator: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
"""Generate snow texture and sparkle mask.
Args:
img_shape (tuple[int, int]): Image shape.
random_generator (np.random.Generator): Random generator to use.
Returns:
tuple[np.ndarray, np.ndarray]: Tuple of (snow_texture, sparkle_mask) arrays.
"""
# Generate base snow texture
snow_texture = random_generator.normal(size=img_shape[:2], loc=0.5, scale=0.3)
snow_texture = cv2.GaussianBlur(snow_texture, (0, 0), sigmaX=1, sigmaY=1)
# Generate sparkle mask
sparkle_mask = random_generator.random(img_shape[:2]) > 0.99
return snow_texture, sparkle_mask
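Example (a sketch; the sparkle mask flags roughly the brightest 1% of random locations):
>>> import numpy as np
>>> from albumentations.augmentations.functional import generate_snow_textures
>>> rng = np.random.default_rng(0)
>>> snow_texture, sparkle_mask = generate_snow_textures((128, 128), rng)
>>> assert snow_texture.shape == (128, 128)
>>> assert sparkle_mask.shape == (128, 128) and sparkle_mask.dtype == bool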
def get_drop_mask (shape, per_channel, dropout_prob, random_generator)
[view source on GitHub]¶
Generate a boolean mask for pixel dropout.
Parameters:
Name | Type | Description |
---|---|---|
shape | tuple[int, ...] | Shape of the input array |
per_channel | bool | Whether to generate independent masks per channel |
dropout_prob | float | Probability of dropping a pixel |
random_generator | np.random.Generator | Random number generator |
Returns:
Type | Description |
---|---|
np.ndarray | Boolean mask matching input shape |
Source code in albumentations/augmentations/functional.py
def get_drop_mask(
shape: tuple[int, ...],
per_channel: bool,
dropout_prob: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Generate a boolean mask for pixel dropout.
Args:
shape: Shape of the input array
per_channel: Whether to generate independent masks per channel
dropout_prob: Probability of dropping a pixel
random_generator: Random number generator
Returns:
Boolean mask matching input shape
"""
if per_channel or len(shape) == 2:
return random_generator.choice(
[True, False],
shape,
p=[dropout_prob, 1 - dropout_prob],
)
# Generate 2D mask and expand to match channels
mask_2d = random_generator.choice(
[True, False],
shape[:2],
p=[dropout_prob, 1 - dropout_prob],
)
# If input is 2D, return 2D mask
if len(shape) == 2:
return mask_2d
# For 3D input, expand and repeat across channels
return np.repeat(mask_2d[..., None], shape[2], axis=2)
def get_fog_particle_radiuses (img_shape, num_particles, fog_intensity, random_generator)
[view source on GitHub]¶
Generate radiuses for fog particles.
Parameters:
Name | Type | Description |
---|---|---|
img_shape | tuple[int, int] | Image shape. |
num_particles | int | Number of fog particles. |
fog_intensity | float | Intensity of the fog effect, between 0 and 1. |
random_generator | np.random.Generator | Random generator to use. |
Returns:
Type | Description |
---|---|
list[int] | List of radiuses for each fog particle. |
Source code in albumentations/augmentations/functional.py
def get_fog_particle_radiuses(
img_shape: tuple[int, int],
num_particles: int,
fog_intensity: float,
random_generator: np.random.Generator,
) -> list[int]:
"""Generate radiuses for fog particles.
Args:
img_shape (tuple[int, int]): Image shape.
num_particles (int): Number of fog particles.
fog_intensity (float): Intensity of the fog effect, between 0 and 1.
random_generator (np.random.Generator): Random generator to use.
Returns:
list[int]: List of radiuses for each fog particle.
"""
height, width = img_shape[:2]
max_fog_radius = max(2, int(min(height, width) * 0.1 * fog_intensity))
min_radius = max(1, max_fog_radius // 2)
return [random_generator.integers(min_radius, max_fog_radius) for _ in range(num_particles)]
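A worked sketch: for a 200x300 image at fog_intensity 0.5, max_fog_radius = int(200 * 0.1 * 0.5) = 10 and min_radius = 5, so every sampled radius falls in [5, 10):
>>> import numpy as np
>>> from albumentations.augmentations.functional import get_fog_particle_radiuses
>>> rng = np.random.default_rng(0)
>>> radiuses = get_fog_particle_radiuses((200, 300), num_particles=8, fog_intensity=0.5, random_generator=rng)
>>> assert len(radiuses) == 8 and all(5 <= r < 10 for r in radiuses)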
def get_histogram_bounds (hist, cutoff)
[view source on GitHub]¶
Find the low and high bounds of the histogram.
Source code in albumentations/augmentations/functional.py
def get_histogram_bounds(hist: np.ndarray, cutoff: float) -> tuple[int, int]:
"""Find the low and high bounds of the histogram."""
if not cutoff:
non_zero_intensities = np.nonzero(hist)[0]
if len(non_zero_intensities) == 0:
return 0, 0
return int(non_zero_intensities[0]), int(non_zero_intensities[-1])
total_pixels = float(hist.sum())
if total_pixels == 0:
return 0, 0
pixels_to_cut = total_pixels * cutoff / 100.0
# Special case for uniform 256-bin histogram
if len(hist) == 256 and np.all(hist == hist[0]):
min_intensity = int(len(hist) * cutoff / 100) # floor division
max_intensity = len(hist) - min_intensity - 1
return min_intensity, max_intensity
# Find minimum intensity
cumsum = 0.0
min_intensity = 0
for i in range(len(hist)):
cumsum += hist[i]
if cumsum >= pixels_to_cut: # Use >= for left bound
min_intensity = i + 1
break
min_intensity = min(min_intensity, len(hist) - 1)
# Find maximum intensity
cumsum = 0.0
max_intensity = len(hist) - 1
for i in range(len(hist) - 1, -1, -1):
cumsum += hist[i]
if cumsum >= pixels_to_cut: # Use >= for right bound
max_intensity = i
break
# Handle edge cases
if min_intensity > max_intensity:
mid_point = (len(hist) - 1) // 2
return mid_point, mid_point
return min_intensity, max_intensity
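For example, with a perfectly uniform 256-bin histogram the dedicated branch trims cutoff percent of intensities from each side (a sketch):
>>> import numpy as np
>>> from albumentations.augmentations.functional import get_histogram_bounds
>>> hist = np.full(256, 10.0)
>>> assert get_histogram_bounds(hist, cutoff=0) == (0, 255)  # no cutoff: first/last non-zero bins
>>> assert get_histogram_bounds(hist, cutoff=2) == (5, 250)  # int(256 * 2 / 100) = 5 bins trimmed per side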
def get_mask_array (data)
[view source on GitHub]¶
Get mask array from input data if it exists.
def get_mud_params (liquid_layer, color, cutout_threshold, sigma, intensity, random_generator)
[view source on GitHub]¶
Generate mud effect parameters based on liquid layer.
Source code in albumentations/augmentations/functional.py
def get_mud_params(
liquid_layer: np.ndarray,
color: np.ndarray,
cutout_threshold: float,
sigma: float,
intensity: float,
random_generator: np.random.Generator,
) -> dict[str, Any]:
"""Generate mud effect parameters based on liquid layer."""
height, width = liquid_layer.shape
# Create initial mask (ensure we have some non-zero values)
mask = (liquid_layer > cutout_threshold).astype(np.float32)
if np.sum(mask) == 0: # If mask is all zeros
# Force minimum coverage of 10%
num_pixels = height * width
num_needed = max(1, int(0.1 * num_pixels)) # At least 1 pixel
flat_indices = random_generator.choice(num_pixels, num_needed, replace=False)
mask = np.zeros_like(liquid_layer, dtype=np.float32)
mask.flat[flat_indices] = 1.0
# Apply Gaussian blur if sigma > 0
if sigma > 0:
mask = cv2.GaussianBlur(
mask,
ksize=(0, 0),
sigmaX=sigma,
sigmaY=sigma,
borderType=cv2.BORDER_REPLICATE,
)
# Safe normalization (avoid division by zero)
mask_max = np.max(mask)
if mask_max > 0:
mask = mask / mask_max
else:
# If mask is somehow all zeros after blur, force some effect
mask[0, 0] = 1.0
# Scale by intensity directly (no minimum)
mask = mask * intensity
# Create mud effect array
mud = np.zeros((height, width, 3), dtype=np.float32)
# Apply color directly - the intensity scaling is already handled
for i in range(3):
mud[..., i] = mask * color[i]
# Create complementary non-mud array
non_mud = np.ones_like(mud)
for i in range(3):
if color[i] > 0:
non_mud[..., i] = np.clip((color[i] - mud[..., i]) / color[i], 0, 1)
else:
non_mud[..., i] = 1.0 - mask
return {
"mud": mud.astype(np.float32),
"non_mud": non_mud.astype(np.float32),
}
def get_normalizer (method)
[view source on GitHub]¶
Get stain normalizer based on method.
def get_rain_params (liquid_layer, color, intensity)
[view source on GitHub]¶
Generate parameters for rain effect.
Source code in albumentations/augmentations/functional.py
def get_rain_params(
liquid_layer: np.ndarray,
color: np.ndarray,
intensity: float,
) -> dict[str, Any]:
"""Generate parameters for rain effect."""
liquid_layer = clip(liquid_layer * 255, np.uint8, inplace=False)
# Generate distance transform with more defined edges
dist = 255 - cv2.Canny(liquid_layer, 50, 150)
dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
_, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
# Use separate blur operations for better drop formation
dist = cv2.GaussianBlur(
dist,
ksize=(3, 3),
sigmaX=1, # Add slight sigma for smoother drops
sigmaY=1,
borderType=cv2.BORDER_REPLICATE,
)
dist = clip(dist, np.uint8, inplace=True)
# Enhance contrast in the distance map
dist = equalize(dist)
# Modified kernel for more natural drop shapes
ker = np.array(
[
[-2, -1, 0],
[-1, 1, 1],
[0, 1, 2],
],
dtype=np.float32,
)
# Apply convolution with better precision
dist = convolve(dist, ker)
# Final blur with larger kernel for smoother drops
dist = cv2.GaussianBlur(
dist,
ksize=(5, 5), # Increased kernel size
sigmaX=1.5, # Adjusted sigma
sigmaY=1.5,
borderType=cv2.BORDER_REPLICATE,
).astype(np.float32)
# Calculate final rain mask with better blending
m = liquid_layer.astype(np.float32) * dist
# Normalize with better handling of edge cases
m_max = np.max(m, axis=(0, 1))
if m_max > 0:
m *= 1 / m_max
else:
m = np.zeros_like(m)
# Apply color with adjusted intensity for more natural look
drops = m[:, :, None] * color * (intensity * 0.9) # Slightly reduced intensity
return {
"drops": drops,
}
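Both get_rain_params and get_mud_params consume a liquid layer (a 2D float noise map) and return ready-to-blend arrays for spatter-style effects. A sketch (the RGB color triples here are illustrative 0-255 values, not canonical defaults):
>>> import numpy as np
>>> from albumentations.augmentations.functional import get_mud_params, get_rain_params
>>> rng = np.random.default_rng(0)
>>> liquid_layer = rng.normal(0.65, 0.3, size=(64, 64)).astype(np.float32)
>>> rain = get_rain_params(liquid_layer, color=np.array([238, 238, 175]), intensity=0.6)  # illustrative color
>>> mud = get_mud_params(liquid_layer, color=np.array([20, 42, 63]), cutout_threshold=0.68, sigma=0.35, intensity=0.6, random_generator=rng)
>>> assert rain["drops"].shape == (64, 64, 3)
>>> assert mud["mud"].shape == (64, 64, 3) and mud["non_mud"].shape == (64, 64, 3)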
def get_safe_brightness_contrast_params (alpha, beta, max_value)
[view source on GitHub]¶
Calculate safe alpha and beta values to prevent overflow/underflow.
For any pixel value x, we want: 0 <= alpha * x + beta <= max_value
Parameters:
Name | Type | Description |
---|---|---|
alpha | float | Contrast factor (1 means no change) |
beta | float | Brightness offset |
max_value | float | Maximum allowed value (255 for uint8, 1 for float32) |
Returns:
Type | Description |
---|---|
tuple[float, float] | Safe (alpha, beta) values that prevent overflow/underflow |
Source code in albumentations/augmentations/functional.py
def get_safe_brightness_contrast_params(
alpha: float,
beta: float,
max_value: float,
) -> tuple[float, float]:
"""Calculate safe alpha and beta values to prevent overflow/underflow.
For any pixel value x, we want: 0 <= alpha * x + beta <= max_value
Args:
alpha: Contrast factor (1 means no change)
beta: Brightness offset
max_value: Maximum allowed value (255 for uint8, 1 for float32)
Returns:
tuple[float, float]: Safe (alpha, beta) values that prevent overflow/underflow
"""
if alpha > 0:
# For x = max_value: alpha * max_value + beta <= max_value
# For x = 0: beta >= 0
safe_beta = np.clip(beta, 0, max_value)
# From alpha * max_value + safe_beta <= max_value
safe_alpha = min(alpha, (max_value - safe_beta) / max_value)
else:
# For x = 0: beta <= max_value
# For x = max_value: alpha * max_value + beta >= 0
safe_beta = min(beta, max_value)
# From alpha * max_value + safe_beta >= 0
safe_alpha = max(alpha, -safe_beta / max_value)
return safe_alpha, safe_beta
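A worked sketch for uint8 (max_value = 255): with alpha = 1.2 and beta = 30, beta passes through and alpha is lowered to (255 - 30) / 255 ≈ 0.882 so the brightest pixel cannot overflow:
>>> from albumentations.augmentations.functional import get_safe_brightness_contrast_params
>>> safe_alpha, safe_beta = get_safe_brightness_contrast_params(1.2, 30, 255)
>>> assert safe_beta == 30
>>> assert abs(safe_alpha - (255 - 30) / 255) < 1e-9
>>> assert safe_alpha * 255 + safe_beta <= 255  # the guarantee the function provides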
def get_tissue_mask (img, threshold=0.85)
[view source on GitHub]¶
Get binary mask of tissue regions based on luminosity.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | RGB image in float32 format, range [0, 1] |
threshold | float | Luminosity threshold. Pixels with luminosity below this value are considered tissue. Range: 0 to 1. Default: 0.85 |
Returns:
Type | Description |
---|---|
np.ndarray | Flattened 1-D binary mask where True indicates tissue regions |
Source code in albumentations/augmentations/functional.py
def get_tissue_mask(img: np.ndarray, threshold: float = 0.85) -> np.ndarray:
"""Get binary mask of tissue regions based on luminosity.
Args:
img: RGB image in float32 format, range [0, 1]
threshold: Luminosity threshold. Pixels with luminosity below this value
are considered tissue. Range: 0 to 1. Default: 0.85
Returns:
Binary mask where True indicates tissue regions
"""
# Convert to grayscale using RGB weights: R*0.299 + G*0.587 + B*0.114
luminosity = img[..., 0] * 0.299 + img[..., 1] * 0.587 + img[..., 2] * 0.114
# Tissue is darker, so we want pixels below threshold
mask = luminosity < threshold
return mask.reshape(-1)
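Example (a sketch; note the mask comes back flattened to 1-D by the reshape above):
>>> import numpy as np
>>> from albumentations.augmentations.functional import get_tissue_mask
>>> img = np.ones((4, 4, 3), dtype=np.float32)  # white background, luminosity 1.0
>>> img[:2, :2] = 0.3                           # darker 2x2 "tissue" block
>>> mask = get_tissue_mask(img, threshold=0.85)
>>> assert mask.shape == (16,) and mask.sum() == 4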
def grayscale_to_multichannel (grayscale_image, num_output_channels=3)
[view source on GitHub]¶
Convert a grayscale image to a multi-channel image.
This function takes a 2D grayscale image or a 3D image with a single channel and converts it to a multi-channel image by repeating the grayscale data across the specified number of channels.
Parameters:
Name | Type | Description |
---|---|---|
grayscale_image | np.ndarray | Input grayscale image. Can be 2D (height, width) or 3D (height, width, 1). |
num_output_channels | int | Number of channels in the output image. Defaults to 3. |
Returns:
Type | Description |
---|---|
np.ndarray | Multi-channel image with shape (height, width, num_channels) |
Source code in albumentations/augmentations/functional.py
def grayscale_to_multichannel(
grayscale_image: np.ndarray,
num_output_channels: int = 3,
) -> np.ndarray:
"""Convert a grayscale image to a multi-channel image.
This function takes a 2D grayscale image or a 3D image with a single channel
and converts it to a multi-channel image by repeating the grayscale data
across the specified number of channels.
Args:
grayscale_image (np.ndarray): Input grayscale image. Can be 2D (height, width)
or 3D (height, width, 1).
num_output_channels (int, optional): Number of channels in the output image. Defaults to 3.
Returns:
np.ndarray: Multi-channel image with shape (height, width, num_channels)
"""
# If output should be single channel, just squeeze and return
if num_output_channels == 1:
return grayscale_image
# For multi-channel output, squeeze and stack
squeezed = np.squeeze(grayscale_image)
return cv2.merge([squeezed] * num_output_channels)
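Example (a sketch):
>>> import numpy as np
>>> from albumentations.augmentations.functional import grayscale_to_multichannel
>>> gray = np.zeros((64, 64), dtype=np.uint8)
>>> rgb = grayscale_to_multichannel(gray, num_output_channels=3)
>>> assert rgb.shape == (64, 64, 3)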
def image_compression (img, quality, image_type)
[view source on GitHub]¶
Apply compression to image.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image |
quality | int | Compression quality (0-100) |
image_type | Literal['.jpg', '.webp'] | Type of compression ('.jpg' or '.webp') |
Returns:
Type | Description |
---|---|
np.ndarray | Compressed image with same number of channels as input |
Source code in albumentations/augmentations/functional.py
@uint8_io
@preserve_channel_dim
def image_compression(
img: np.ndarray,
quality: int,
image_type: Literal[".jpg", ".webp"],
) -> np.ndarray:
"""Apply compression to image.
Args:
img: Input image
quality: Compression quality (0-100)
image_type: Type of compression ('.jpg' or '.webp')
Returns:
Compressed image with same number of channels as input
"""
quality_flag = cv2.IMWRITE_JPEG_QUALITY if image_type == ".jpg" else cv2.IMWRITE_WEBP_QUALITY
num_channels = get_num_channels(img)
if num_channels == 1:
# For grayscale, ensure we read back as single channel
_, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality))
decoded = cv2.imdecode(encoded_img, cv2.IMREAD_GRAYSCALE)
return decoded[..., np.newaxis] # Add channel dimension back
if num_channels == NUM_RGB_CHANNELS:
# Standard RGB image
_, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality))
return cv2.imdecode(encoded_img, cv2.IMREAD_UNCHANGED)
# For 2,4 or more channels, we need to handle alpha/extra channels separately
if num_channels == 2:
# For 2 channels, pad to 3 channels and take only first 2 after compression
padded = np.pad(img, ((0, 0), (0, 0), (0, 1)), mode="constant")
_, encoded_bgr = cv2.imencode(image_type, padded, (int(quality_flag), quality))
decoded_bgr = cv2.imdecode(encoded_bgr, cv2.IMREAD_UNCHANGED)
return decoded_bgr[..., :2]
# Process first 3 channels together
bgr = img[..., :NUM_RGB_CHANNELS]
_, encoded_bgr = cv2.imencode(image_type, bgr, (int(quality_flag), quality))
decoded_bgr = cv2.imdecode(encoded_bgr, cv2.IMREAD_UNCHANGED)
if num_channels > NUM_RGB_CHANNELS:
# Process additional channels one by one
extra_channels = []
for i in range(NUM_RGB_CHANNELS, num_channels):
channel = img[..., i]
_, encoded = cv2.imencode(image_type, channel, (int(quality_flag), quality))
decoded = cv2.imdecode(encoded, cv2.IMREAD_GRAYSCALE)
if len(decoded.shape) == 2:
decoded = decoded[..., np.newaxis]
extra_channels.append(decoded)
# Combine BGR with extra channels
return np.dstack([decoded_bgr, *extra_channels])
return decoded_bgr
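Example (a sketch; at low quality, JPEG artifacts appear but shape, dtype, and channel count are preserved):
>>> import numpy as np
>>> from albumentations.augmentations.functional import image_compression
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
>>> compressed = image_compression(image, quality=10, image_type=".jpg")
>>> assert compressed.shape == image.shape and compressed.dtype == np.uint8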
def iso_noise (image, color_shift, intensity, random_generator)
[view source on GitHub]¶
Apply Poisson noise to an image to simulate camera sensor noise.
Parameters:
Name | Type | Description |
---|---|---|
image | np.ndarray | Input image. Currently, only RGB images are supported. |
color_shift | float | The amount of color shift to apply. |
intensity | float | Multiplication factor for noise values. Values of ~0.5 produce a noticeable, yet acceptable level of noise. |
random_generator | np.random.Generator | Random generator used for noise generation. |
Returns:
Type | Description |
---|---|
np.ndarray | The noised image. |
Image types: uint8, float32
Number of channels: 3
Source code in albumentations/augmentations/functional.py
@float32_io
@clipped
def iso_noise(
image: np.ndarray,
color_shift: float,
intensity: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Apply poisson noise to an image to simulate camera sensor noise.
Args:
image (np.ndarray): Input image. Currently, only RGB images are supported.
color_shift (float): The amount of color shift to apply.
intensity (float): Multiplication factor for noise values. Values of ~0.5 produce a noticeable,
yet acceptable level of noise.
random_generator (np.random.Generator): Random generator used for noise generation.
Returns:
np.ndarray: The noised image.
Image types:
uint8, float32
Number of channels:
3
"""
hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)
_, stddev = cv2.meanStdDev(hls)
luminance_noise = random_generator.poisson(
stddev[1] * intensity,
size=hls.shape[:2],
)
color_noise = random_generator.normal(
0,
color_shift * intensity,
size=hls.shape[:2],
)
hls[..., 0] += color_noise
hls[..., 1] = add_array(
hls[..., 1],
luminance_noise * intensity * (1.0 - hls[..., 1]),
)
noised_hls = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB)
return np.clip(noised_hls, 0, 1, out=noised_hls) # Ensure output is in [0, 1] range
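Example (a sketch; the float32_io decorator converts uint8 inputs to [0, 1] floats and back transparently):
>>> import numpy as np
>>> from albumentations.augmentations.functional import iso_noise
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
>>> noised = iso_noise(image, color_shift=0.05, intensity=0.5, random_generator=rng)
>>> assert noised.shape == image.shape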
def move_tone_curve (img, low_y, high_y)
[view source on GitHub]¶
Rescales the relationship between bright and dark areas of the image by manipulating its tone curve.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. Any number of channels |
low_y | float | np.ndarray | per-channel or single y-position of a Bezier control point used to adjust the tone curve, must be in range [0, 1] |
high_y | float | np.ndarray | per-channel or single y-position of a Bezier control point used to adjust image tone curve, must be in range [0, 1] |
Source code in albumentations/augmentations/functional.py
@uint8_io
def move_tone_curve(
img: np.ndarray,
low_y: float | np.ndarray,
high_y: float | np.ndarray,
) -> np.ndarray:
"""Rescales the relationship between bright and dark areas of the image by manipulating its tone curve.
Args:
img: np.ndarray. Any number of channels
low_y: per-channel or single y-position of a Bezier control point used
to adjust the tone curve, must be in range [0, 1]
high_y: per-channel or single y-position of a Bezier control point used
to adjust image tone curve, must be in range [0, 1]
"""
t = np.linspace(0.0, 1.0, 256)
def evaluate_bez(
t: np.ndarray,
low_y: float | np.ndarray,
high_y: float | np.ndarray,
) -> np.ndarray:
one_minus_t = 1 - t
return (3 * one_minus_t**2 * t * low_y + 3 * one_minus_t * t**2 * high_y + t**3) * 255
num_channels = get_num_channels(img)
if np.isscalar(low_y) and np.isscalar(high_y):
lut = clip(np.rint(evaluate_bez(t, low_y, high_y)), np.uint8, inplace=False)
return sz_lut(img, lut, inplace=False)
if isinstance(low_y, np.ndarray) and isinstance(high_y, np.ndarray):
luts = clip(
np.rint(evaluate_bez(t[:, np.newaxis], low_y, high_y).T),
np.uint8,
inplace=False,
)
return cv2.merge(
[sz_lut(img[:, :, i], np.ascontiguousarray(luts[i]), inplace=False) for i in range(num_channels)],
)
raise TypeError(
f"low_y and high_y must both be of type float or np.ndarray. Got {type(low_y)} and {type(high_y)}",
)
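Example of both accepted parameter forms (a sketch; scalars apply one curve to every channel, arrays apply per-channel curves):
>>> import numpy as np
>>> from albumentations.augmentations.functional import move_tone_curve
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
>>> global_curve = move_tone_curve(image, low_y=0.1, high_y=0.6)
>>> per_channel = move_tone_curve(image, low_y=np.array([0.1, 0.2, 0.3]), high_y=np.array([0.6, 0.7, 0.8]))
>>> assert global_curve.shape == image.shape and per_channel.shape == image.shape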
def order_stains_combined (stain_colors)
[view source on GitHub]¶
Order stains using a combination of methods.
This combines both angular information and spectral characteristics for more robust identification.
Source code in albumentations/augmentations/functional.py
def order_stains_combined(stain_colors: np.ndarray) -> tuple[int, int]:
"""Order stains using a combination of methods.
This combines both angular information and spectral characteristics
for more robust identification.
"""
# Normalize stain vectors
stain_colors = normalize_vectors(stain_colors)
# Calculate angles (Macenko)
angles = np.mod(np.arctan2(stain_colors[:, 1], stain_colors[:, 0]), np.pi)
# Calculate spectral ratios (Ruifrok)
blue_ratio = stain_colors[:, 2] / (np.sum(stain_colors, axis=1) + 1e-6)
red_ratio = stain_colors[:, 0] / (np.sum(stain_colors, axis=1) + 1e-6)
# Combine scores
# High angle and high blue ratio indicates Hematoxylin
# Low angle and high red ratio indicates Eosin
scores = angles * blue_ratio - red_ratio
hematoxylin_idx = np.argmax(scores)
eosin_idx = 1 - hematoxylin_idx
return hematoxylin_idx, eosin_idx
def pixel_dropout (image, drop_mask, drop_values)
[view source on GitHub]¶
Apply pixel dropout to an image.
Parameters:
Name | Type | Description |
---|---|---|
image | np.ndarray | Input image |
drop_mask | np.ndarray | Boolean mask of same shape as image indicating pixels to drop |
drop_values | np.ndarray | Values to use for dropped pixels, same shape as image |
Returns:
Type | Description |
---|---|
np.ndarray | Image with pixels dropped according to mask |
Source code in albumentations/augmentations/functional.py
@preserve_channel_dim
def pixel_dropout(
image: np.ndarray,
drop_mask: np.ndarray,
drop_values: np.ndarray,
) -> np.ndarray:
"""Apply pixel dropout to an image.
Args:
image: Input image
drop_mask: Boolean mask of same shape as image indicating pixels to drop
drop_values: Values to use for dropped pixels, same shape as image
Returns:
Image with pixels dropped according to mask
"""
return np.where(drop_mask, drop_values, image)
def posterize (img, bits)
[view source on GitHub]¶
Reduce the number of bits for each color channel by keeping only the highest N bits.
This transform performs bit-depth reduction by masking out lower bits, effectively reducing the number of possible values per channel. This creates a posterization effect where similar colors are merged together.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image. Can be single or multi-channel. |
bits | Literal[1, 2, 3, 4, 5, 6, 7] | list[Literal[1, 2, 3, 4, 5, 6, 7]] | Number of high bits to keep. Must be in range [1, 7]. Can be either: - A single value to apply the same bit reduction to all channels - A list of values to apply different bit reduction per channel. Length of list must match number of channels in image. |
Returns:
Type | Description |
---|---|
np.ndarray | Image with reduced bit depth. Has same shape and dtype as input. |
Note
- The transform keeps the N highest bits and sets all other bits to 0
- For example, if bits=3:
- Original value: 11010110 (214)
- Keep 3 bits: 11000000 (192)
- The number of unique colors per channel will be 2^bits
- Higher bits values = more colors = more subtle effect
- Lower bits values = fewer colors = more dramatic posterization
Examples:
>>> import numpy as np
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> # Same posterization for all channels
>>> result = posterize(image, bits=3)
>>> # Different posterization per channel
>>> result = posterize(image, bits=[3, 4, 5]) # RGB channels
Source code in albumentations/augmentations/functional.py
@uint8_io
@clipped
def posterize(img: np.ndarray, bits: Literal[1, 2, 3, 4, 5, 6, 7] | list[Literal[1, 2, 3, 4, 5, 6, 7]]) -> np.ndarray:
"""Reduce the number of bits for each color channel by keeping only the highest N bits.
This transform performs bit-depth reduction by masking out lower bits, effectively
reducing the number of possible values per channel. This creates a posterization
effect where similar colors are merged together.
Args:
img: Input image. Can be single or multi-channel.
bits: Number of high bits to keep. Must be in range [1, 7].
Can be either:
- A single value to apply the same bit reduction to all channels
- A list of values to apply different bit reduction per channel.
Length of list must match number of channels in image.
Returns:
np.ndarray: Image with reduced bit depth. Has same shape and dtype as input.
Note:
- The transform keeps the N highest bits and sets all other bits to 0
- For example, if bits=3:
- Original value: 11010110 (214)
- Keep 3 bits: 11000000 (192)
- The number of unique colors per channel will be 2^bits
- Higher bits values = more colors = more subtle effect
- Lower bits values = fewer colors = more dramatic posterization
Examples:
>>> import numpy as np
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> # Same posterization for all channels
>>> result = posterize(image, bits=3)
>>> # Different posterization per channel
>>> result = posterize(image, bits=[3, 4, 5]) # RGB channels
"""
bits_array = np.uint8(bits)
if not bits_array.shape or len(bits_array) == 1:
lut = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - bits_array) - 1)
lut &= mask
return sz_lut(img, lut, inplace=False)
result_img = np.empty_like(img)
for i, channel_bits in enumerate(bits_array):
lut = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
lut &= mask
result_img[..., i] = sz_lut(img[..., i], lut, inplace=True)
return result_img
def prepare_drop_values (array, value, random_generator)
[view source on GitHub]¶
Prepare values to fill dropped pixels.
Parameters:
Name | Type | Description |
---|---|---|
array | np.ndarray | Input array to determine shape and dtype |
value | float | Sequence[float] | np.ndarray | None | User-specified drop values or None for random |
random_generator | np.random.Generator | Random number generator |
Returns:
Type | Description |
---|---|
np.ndarray | Array of values matching input shape |
Source code in albumentations/augmentations/functional.py
def prepare_drop_values(
array: np.ndarray,
value: float | Sequence[float] | np.ndarray | None,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Prepare values to fill dropped pixels.
Args:
array: Input array to determine shape and dtype
value: User-specified drop values or None for random
random_generator: Random number generator
Returns:
Array of values matching input shape
"""
if value is None:
channels = get_num_channels(array)
values = generate_random_values(channels, array.dtype, random_generator)
elif isinstance(value, (int, float)):
return np.full(array.shape, value, dtype=array.dtype)
else:
values = np.array(value, dtype=array.dtype).reshape(-1)
# For monochannel input, return single value
if array.ndim == 2:
return np.full(array.shape, values[0], dtype=array.dtype)
# For multichannel input, broadcast values to full shape
return np.full(array.shape[:2] + (len(values),), values, dtype=array.dtype)
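The dropout helpers compose into a small pipeline: get_drop_mask selects pixels, prepare_drop_values supplies fill values (delegating to generate_random_values when value is None), and pixel_dropout applies them. A sketch:
>>> import numpy as np
>>> from albumentations.augmentations.functional import (
...     get_drop_mask, pixel_dropout, prepare_drop_values,
... )
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (32, 32, 3), dtype=np.uint8)
>>> mask = get_drop_mask(image.shape, per_channel=False, dropout_prob=0.05, random_generator=rng)
>>> values = prepare_drop_values(image, value=None, random_generator=rng)  # random per-channel fill
>>> dropped = pixel_dropout(image, mask, values)
>>> assert dropped.shape == image.shape and dropped.dtype == image.dtype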
def sample_beta (size, params, random_generator)
[view source on GitHub]¶
Sample from Beta distribution.
The Beta distribution is bounded by [0, 1] and then scaled and shifted to [-scale, scale]. Alpha and beta parameters control the shape of the distribution.
Source code in albumentations/augmentations/functional.py
def sample_beta(
size: tuple[int, ...],
params: dict[str, Any],
random_generator: np.random.Generator,
) -> np.ndarray:
"""Sample from Beta distribution.
The Beta distribution is bounded by [0, 1] and then scaled and shifted to [-scale, scale].
Alpha and beta parameters control the shape of the distribution.
"""
alpha = random_generator.uniform(*params["alpha_range"])
beta = random_generator.uniform(*params["beta_range"])
scale = random_generator.uniform(*params["scale_range"])
# Sample from Beta[0,1] and transform to [-scale,scale]
samples = random_generator.beta(alpha, beta, size=size)
return (2 * samples - 1) * scale
def sample_gaussian (size, params, random_generator)
[view source on GitHub]¶
Sample from Gaussian distribution.
Source code in albumentations/augmentations/functional.py
def sample_gaussian(
size: tuple[int, ...],
params: dict[str, Any],
random_generator: np.random.Generator,
) -> np.ndarray:
"""Sample from Gaussian distribution."""
mean = (
params["mean_range"][0]
if params["mean_range"][0] == params["mean_range"][1]
else random_generator.uniform(*params["mean_range"])
)
std = (
params["std_range"][0]
if params["std_range"][0] == params["std_range"][1]
else random_generator.uniform(*params["std_range"])
)
num_channels = size[2] if len(size) > MONO_CHANNEL_DIMENSIONS else 1
mean_vector = mean * np.ones(shape=(num_channels,), dtype=np.float32)
std_dev_vector = std * np.ones(shape=(num_channels,), dtype=np.float32)
gaussian_sampled_arr = np.zeros(shape=size)
cv2.randn(dst=gaussian_sampled_arr, mean=mean_vector, stddev=std_dev_vector)
return gaussian_sampled_arr.astype(np.float32)
def sample_laplace (size, params, random_generator)
[view source on GitHub]¶
Sample from Laplace distribution.
The Laplace distribution is also known as the double exponential distribution. It has heavier tails than the Gaussian distribution.
Source code in albumentations/augmentations/functional.py
def sample_laplace(
size: tuple[int, ...],
params: dict[str, Any],
random_generator: np.random.Generator,
) -> np.ndarray:
"""Sample from Laplace distribution.
The Laplace distribution is also known as the double exponential distribution.
It has heavier tails than the Gaussian distribution.
"""
loc = random_generator.uniform(*params["mean_range"])
scale = random_generator.uniform(*params["scale_range"])
return random_generator.laplace(loc=loc, scale=scale, size=size)
def sample_noise (noise_type, size, params, max_value, random_generator)
[view source on GitHub]¶
Sample from specific noise distribution.
Source code in albumentations/augmentations/functional.py
def sample_noise(
noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
size: tuple[int, ...],
params: dict[str, Any],
max_value: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Sample from specific noise distribution."""
if noise_type == "uniform":
return sample_uniform(size, params, random_generator) * max_value
if noise_type == "gaussian":
return sample_gaussian(size, params, random_generator) * max_value
if noise_type == "laplace":
return sample_laplace(size, params, random_generator) * max_value
if noise_type == "beta":
return sample_beta(size, params, random_generator) * max_value
raise ValueError(f"Unknown noise type: {noise_type}")
def sample_uniform (size, params, random_generator)
[view source on GitHub]¶
Sample from uniform distribution.
Parameters:
Name | Type | Description |
---|---|---|
size | tuple[int, ...] | Output shape. If length is 1, generates constant noise per channel. |
params | dict[str, Any] | Must contain 'ranges' key with list of (min, max) tuples. If only one range is provided, it will be used for all channels. |
random_generator | np.random.Generator | NumPy random generator instance |
Returns:
Type | Description |
---|---|
np.ndarray | float | Noise array of specified size. For single-channel constant mode, returns scalar instead of array with shape (1,). |
Source code in albumentations/augmentations/functional.py
def sample_uniform(
size: tuple[int, ...],
params: dict[str, Any],
random_generator: np.random.Generator,
) -> np.ndarray | float:
"""Sample from uniform distribution.
Args:
size: Output shape. If length is 1, generates constant noise per channel.
params: Must contain 'ranges' key with list of (min, max) tuples.
If only one range is provided, it will be used for all channels.
random_generator: NumPy random generator instance
Returns:
Noise array of specified size. For single-channel constant mode,
returns scalar instead of array with shape (1,).
"""
if len(size) == 1: # constant mode
ranges = params["ranges"]
num_channels = size[0]
if len(ranges) == 1:
ranges = ranges * num_channels
elif len(ranges) < num_channels:
raise ValueError(
f"Not enough ranges provided. Expected {num_channels}, got {len(ranges)}",
)
return np.array(
[random_generator.uniform(low, high) for low, high in ranges[:num_channels]],
)
# use first range for spatial noise
low, high = params["ranges"][0]
return random_generator.uniform(low, high, size=size)
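All four samplers share one calling convention through sample_noise; each distribution reads its own keys from params, with ranges given as (low, high) tuples sampled once per call. A sketch:
>>> import numpy as np
>>> from albumentations.augmentations.functional import sample_noise
>>> rng = np.random.default_rng(0)
>>> size = (16, 16, 3)
>>> gaussian = sample_noise("gaussian", size, {"mean_range": (0, 0), "std_range": (0.05, 0.1)}, 255, rng)
>>> laplace = sample_noise("laplace", size, {"mean_range": (0, 0), "scale_range": (0.05, 0.1)}, 255, rng)
>>> beta = sample_noise("beta", size, {"alpha_range": (0.5, 1.5), "beta_range": (0.5, 1.5), "scale_range": (0.1, 0.3)}, 255, rng)
>>> uniform = sample_noise("uniform", size, {"ranges": [(-0.1, 0.1)]}, 255, rng)
>>> assert gaussian.shape == laplace.shape == beta.shape == uniform.shape == size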
def sharpen_gaussian (img, alpha, kernel_size, sigma)
[view source on GitHub]¶
Sharpen image using Gaussian blur.
Source code in albumentations/augmentations/functional.py
@clipped
@preserve_channel_dim
def sharpen_gaussian(
img: np.ndarray,
alpha: float,
kernel_size: int,
sigma: float,
) -> np.ndarray:
"""Sharpen image using Gaussian blur."""
blurred = cv2.GaussianBlur(
img,
ksize=(kernel_size, kernel_size),
sigmaX=sigma,
sigmaY=sigma,
)
# Unsharp mask formula: original + alpha * (original - blurred)
# This is equivalent to: original * (1 + alpha) - alpha * blurred
return img + alpha * (img - blurred)
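Example (a sketch; kernel_size must be odd for OpenCV, and alpha controls how strongly the detail layer is added back):
>>> import numpy as np
>>> from albumentations.augmentations.functional import sharpen_gaussian
>>> rng = np.random.default_rng(0)
>>> image = rng.random((64, 64, 3), dtype=np.float32)
>>> sharpened = sharpen_gaussian(image, alpha=0.5, kernel_size=5, sigma=1.0)
>>> assert sharpened.shape == image.shape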
def shot_noise (img, scale, random_generator)
[view source on GitHub]¶
Apply shot noise to the image by simulating photon counting in linear light space.
This function simulates photon shot noise, which occurs due to the quantum nature of light. The process:
1. Converts image to linear light space (removes gamma correction)
2. Scales pixel values to represent expected photon counts
3. Samples actual photon counts from Poisson distribution
4. Converts back to display space (reapplies gamma)
The simulation is performed in linear light space because photon shot noise is a physical process that occurs before gamma correction is applied by cameras/displays.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image in range [0, 1]. Can be single or multi-channel. |
scale | float | Reciprocal of the number of photons (noise intensity). - Larger values = fewer photons = more noise - Smaller values = more photons = less noise For example: - scale = 0.1 simulates ~100 photons per unit intensity - scale = 10.0 simulates ~0.1 photons per unit intensity |
random_generator | np.random.Generator | NumPy random generator for Poisson sampling |
Returns:
Type | Description |
---|---|
np.ndarray | Image with shot noise applied, same shape and range [0, 1] as input. The noise characteristics follow Poisson statistics in linear space: variance equals the mean, with more absolute (but less relative) noise in brighter regions. |
Note
- Uses gamma value of 2.2 for linear/display space conversion
- Adds small constant (1e-6) to avoid issues with zero values
- Clips final values to [0, 1] range
- Operates on the image in-place for memory efficiency
- Preserves float32 precision throughout calculations
Source code in albumentations/augmentations/functional.py
@preserve_channel_dim
@float32_io
def shot_noise(
img: np.ndarray,
scale: float,
random_generator: np.random.Generator,
) -> np.ndarray:
"""Apply shot noise to the image by simulating photon counting in linear light space.
This function simulates photon shot noise, which occurs due to the quantum nature of light.
The process:
1. Converts image to linear light space (removes gamma correction)
2. Scales pixel values to represent expected photon counts
3. Samples actual photon counts from Poisson distribution
4. Converts back to display space (reapplies gamma)
The simulation is performed in linear light space because photon shot noise is a physical
process that occurs before gamma correction is applied by cameras/displays.
Args:
img: Input image in range [0, 1]. Can be single or multi-channel.
scale: Reciprocal of the number of photons (noise intensity).
- Larger values = fewer photons = more noise
- Smaller values = more photons = less noise
For example:
- scale = 0.1 simulates ~100 photons per unit intensity
- scale = 10.0 simulates ~0.1 photons per unit intensity
random_generator: NumPy random generator for Poisson sampling
Returns:
Image with shot noise applied, same shape and range [0, 1] as input.
The noise characteristics will follow Poisson statistics in linear space:
- Variance equals mean in linear space
- More noise in brighter regions (but less relative noise)
- Less noise in darker regions (but more relative noise)
Note:
- Uses gamma value of 2.2 for linear/display space conversion
- Adds small constant (1e-6) to avoid issues with zero values
- Clips final values to [0, 1] range
- Operates on the image in-place for memory efficiency
- Preserves float32 precision throughout calculations
References:
- https://en.wikipedia.org/wiki/Shot_noise
- https://en.wikipedia.org/wiki/Gamma_correction
"""
# Apply inverse gamma correction to work in linear space
img_linear = cv2.pow(img, 2.2)
# Scale image values and add small constant to avoid zero values
scaled_img = (img_linear + scale * 1e-6) / scale
# Generate Poisson noise
noisy_img = multiply_by_constant(
random_generator.poisson(scaled_img).astype(np.float32),
scale,
inplace=True,
)
# Scale back and apply gamma correction
return power(np.clip(noisy_img, 0, 1, out=noisy_img), 1 / 2.2)
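Example (a sketch; larger scale means fewer simulated photons and visibly stronger noise):
>>> import numpy as np
>>> from albumentations.augmentations.functional import shot_noise
>>> rng = np.random.default_rng(0)
>>> image = rng.random((64, 64, 3), dtype=np.float32)  # float32 in [0, 1]
>>> noisy = shot_noise(image, scale=0.2, random_generator=rng)
>>> assert noisy.shape == image.shape
>>> assert float(noisy.min()) >= 0.0 and float(noisy.max()) <= 1.0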
def slic (image, n_segments, compactness=10.0, max_iterations=10)
[view source on GitHub]¶
Simple Linear Iterative Clustering (SLIC) superpixel segmentation using OpenCV and NumPy.
Parameters:
Name | Type | Description |
---|---|---|
image | np.ndarray | Input image (2D or 3D numpy array). |
n_segments | int | Approximate number of superpixels to generate. |
compactness | float | Balance between color proximity and space proximity. |
max_iterations | int | Maximum number of iterations for k-means. |
Returns:
Type | Description |
---|---|
np.ndarray | Segmentation mask where each superpixel has a unique label. |
Source code in albumentations/augmentations/functional.py
def slic(
image: np.ndarray,
n_segments: int,
compactness: float = 10.0,
max_iterations: int = 10,
) -> np.ndarray:
"""Simple Linear Iterative Clustering (SLIC) superpixel segmentation using OpenCV and NumPy.
Args:
image (np.ndarray): Input image (2D or 3D numpy array).
n_segments (int): Approximate number of superpixels to generate.
compactness (float): Balance between color proximity and space proximity.
max_iterations (int): Maximum number of iterations for k-means.
Returns:
np.ndarray: Segmentation mask where each superpixel has a unique label.
"""
if image.ndim == MONO_CHANNEL_DIMENSIONS:
image = image[..., np.newaxis]
height, width = image.shape[:2]
num_pixels = height * width
# Normalize image to [0, 1] range
image_normalized = image.astype(np.float32) / np.max(image + 1e-6)
# Initialize cluster centers
grid_step = int((num_pixels / n_segments) ** 0.5)
x_range = np.arange(grid_step // 2, width, grid_step)
y_range = np.arange(grid_step // 2, height, grid_step)
centers = np.array(
[(x, y) for y in y_range for x in x_range if x < width and y < height],
)
# Initialize labels and distances
labels = -1 * np.ones((height, width), dtype=np.int32)
distances = np.full((height, width), np.inf)
for _ in range(max_iterations):
for i, center in enumerate(centers):
y, x = int(center[1]), int(center[0])
# Define the neighborhood
y_low, y_high = max(0, y - grid_step), min(height, y + grid_step + 1)
x_low, x_high = max(0, x - grid_step), min(width, x + grid_step + 1)
# Compute distances
crop = image_normalized[y_low:y_high, x_low:x_high]
color_diff = crop - image_normalized[y, x]
color_distance = np.sum(color_diff**2, axis=-1)
yy, xx = np.ogrid[y_low:y_high, x_low:x_high]
spatial_distance = ((yy - y) ** 2 + (xx - x) ** 2) / (grid_step**2)
distance = color_distance + compactness * spatial_distance
mask = distance < distances[y_low:y_high, x_low:x_high]
distances[y_low:y_high, x_low:x_high][mask] = distance[mask]
labels[y_low:y_high, x_low:x_high][mask] = i
# Update centers
for i in range(len(centers)):
mask = labels == i
if np.any(mask):
centers[i] = np.mean(np.argwhere(mask), axis=0)[::-1]
return labels
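Example (a sketch; the returned mask is dense integer labels, with roughly n_segments distinct superpixels):
>>> import numpy as np
>>> from albumentations.augmentations.functional import slic
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
>>> labels = slic(image, n_segments=16, compactness=10.0, max_iterations=5)
>>> assert labels.shape == (64, 64)
>>> assert len(np.unique(labels)) <= 17  # at most n_segments labels (plus any unassigned -1)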
def solarize (img, threshold)
[view source on GitHub]¶
Invert all pixel values above a threshold.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | The image to solarize. Can be uint8 or float32. |
threshold | float | Normalized threshold value in range [0, 1]. For uint8 images: pixels above threshold * 255 are inverted For float32 images: pixels above threshold are inverted |
Returns:
Type | Description |
---|---|
np.ndarray | Solarized image. |
Note
The threshold is normalized to [0, 1] range for both uint8 and float32 images. For uint8 images, the threshold is internally scaled by 255.
Source code in albumentations/augmentations/functional.py
@clipped
def solarize(img: np.ndarray, threshold: float) -> np.ndarray:
"""Invert all pixel values above a threshold.
Args:
img: The image to solarize. Can be uint8 or float32.
threshold: Normalized threshold value in range [0, 1].
For uint8 images: pixels above threshold * 255 are inverted
For float32 images: pixels above threshold are inverted
Returns:
Solarized image.
Note:
The threshold is normalized to [0, 1] range for both uint8 and float32 images.
For uint8 images, the threshold is internally scaled by 255.
"""
dtype = img.dtype
max_val = MAX_VALUES_BY_DTYPE[dtype]
if dtype == np.uint8:
lut = [(max_val - i if i >= threshold * max_val else i) for i in range(int(max_val) + 1)]
prev_shape = img.shape
img = sz_lut(img, np.array(lut, dtype=dtype), inplace=False)
return np.expand_dims(img, -1) if len(prev_shape) != img.ndim else img
img = img.copy()
cond = img >= threshold
img[cond] = max_val - img[cond]
return img
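Example (a sketch; the threshold is always specified in the normalized [0, 1] range):
>>> import numpy as np
>>> from albumentations.augmentations.functional import solarize
>>> img_u8 = np.array([[0, 100, 200, 255]], dtype=np.uint8)
>>> out_u8 = solarize(img_u8, threshold=0.5)  # pixels >= 127.5 are inverted
>>> assert (out_u8 == np.array([[0, 100, 55, 0]])).all()
>>> img_f32 = np.array([[0.1, 0.6]], dtype=np.float32)
>>> out_f32 = solarize(img_f32, threshold=0.5)
>>> assert np.allclose(out_f32, [[0.1, 0.4]])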
def to_gray_average (img)
[view source on GitHub]¶
Convert an image to grayscale using the average method.
This function computes the arithmetic mean across all channels for each pixel, resulting in a grayscale representation of the image.
Key aspects of this method:
1. It treats all channels equally, regardless of their perceptual importance.
2. Works with any number of channels, making it versatile for various image types.
3. Simple and fast to compute, but may not accurately represent perceived brightness.
4. For RGB images, the formula is: Gray = (R + G + B) / 3
Note: This method may produce different results compared to weighted methods (like RGB weighted average) which account for human perception of color brightness. It may also produce unexpected results for images with alpha channels or non-color data in additional channels.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image as a numpy array. Can be any number of channels. |
Returns:
Type | Description |
---|---|
np.ndarray | Grayscale image as a 2D numpy array. The output data type matches the input data type. |
Image types: uint8, float32
Number of channels: any
Source code in albumentations/augmentations/functional.py
def to_gray_average(img: np.ndarray) -> np.ndarray:
"""Convert an image to grayscale using the average method.
This function computes the arithmetic mean across all channels for each pixel,
resulting in a grayscale representation of the image.
Key aspects of this method:
1. It treats all channels equally, regardless of their perceptual importance.
2. Works with any number of channels, making it versatile for various image types.
3. Simple and fast to compute, but may not accurately represent perceived brightness.
4. For RGB images, the formula is: Gray = (R + G + B) / 3
Note: This method may produce different results compared to weighted methods
(like RGB weighted average) which account for human perception of color brightness.
It may also produce unexpected results for images with alpha channels or
non-color data in additional channels.
Args:
img (np.ndarray): Input image as a numpy array. Can be any number of channels.
Returns:
np.ndarray: Grayscale image as a 2D numpy array. The output data type
matches the input data type.
Image types:
uint8, float32
Number of channels:
any
"""
return np.mean(img, axis=-1).astype(img.dtype)
def to_gray_desaturation (img)
[view source on GitHub]¶
Convert an image to grayscale using the desaturation method.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image as a numpy array. |
Returns:
Type | Description |
---|---|
np.ndarray | Grayscale image as a 2D numpy array. |
Image types: uint8, float32
Number of channels: any
Source code in albumentations/augmentations/functional.py
@clipped
def to_gray_desaturation(img: np.ndarray) -> np.ndarray:
"""Convert an image to grayscale using the desaturation method.
Args:
img (np.ndarray): Input image as a numpy array.
Returns:
np.ndarray: Grayscale image as a 2D numpy array.
Image types:
uint8, float32
Number of channels:
any
"""
float_image = img.astype(np.float32)
return (np.max(float_image, axis=-1) + np.min(float_image, axis=-1)) / 2
def to_gray_from_lab (img)
[view source on GitHub]¶
Convert an RGB image to grayscale using the L channel from the LAB color space.
This function converts the RGB image to the LAB color space and extracts the L channel. The LAB color space is designed to approximate human vision, where L represents lightness.
Key aspects of this method:
1. The L channel represents the lightness of each pixel, ranging from 0 (black) to 100 (white).
2. It's more perceptually uniform than RGB, meaning equal changes in L values correspond to roughly equal changes in perceived lightness.
3. The L channel is independent of the color information (A and B channels), making it suitable for grayscale conversion.
This method can be particularly useful when you want a grayscale image that closely matches human perception of lightness, potentially preserving more perceived contrast than simple RGB-based methods.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input RGB image as a numpy array. |
Returns:
Type | Description |
---|---|
np.ndarray | Grayscale image as a 2D numpy array, representing the L (lightness) channel. Values are scaled to match the input image's data type range. |
Image types: uint8, float32
Number of channels: 3
Source code in albumentations/augmentations/functional.py
@uint8_io
@clipped
def to_gray_from_lab(img: np.ndarray) -> np.ndarray:
"""Convert an RGB image to grayscale using the L channel from the LAB color space.
This function converts the RGB image to the LAB color space and extracts the L channel.
The LAB color space is designed to approximate human vision, where L represents lightness.
Key aspects of this method:
1. The L channel represents the lightness of each pixel, ranging from 0 (black) to 100 (white).
2. It's more perceptually uniform than RGB, meaning equal changes in L values correspond to
roughly equal changes in perceived lightness.
3. The L channel is independent of the color information (A and B channels), making it
suitable for grayscale conversion.
This method can be particularly useful when you want a grayscale image that closely
matches human perception of lightness, potentially preserving more perceived contrast
than simple RGB-based methods.
Args:
img (np.ndarray): Input RGB image as a numpy array.
Returns:
np.ndarray: Grayscale image as a 2D numpy array, representing the L (lightness) channel.
Values are scaled to match the input image's data type range.
Image types:
uint8, float32
Number of channels:
3
"""
return cv2.cvtColor(img, cv2.COLOR_RGB2LAB)[..., 0]
def to_gray_max (img)
[view source on GitHub]¶
Convert an image to grayscale using the maximum channel value method.
This function takes the maximum value across all channels for each pixel, resulting in a grayscale image that preserves the brightest parts of the original image.
Key aspects of this method:
1. Works with any number of channels, making it versatile for various image types.
2. For 3-channel (e.g., RGB) images, this method is equivalent to extracting the V (Value) channel from the HSV color space.
3. Preserves the brightest parts of the image but may lose some color contrast information.
4. Simple and fast to compute.
Note:
- This method tends to produce brighter grayscale images compared to other conversion methods, as it always selects the highest intensity value from the channels.
- For RGB images, it may not accurately represent perceived brightness as it doesn't account for human color perception.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image as a numpy array. Can be any number of channels. |
Returns:
Type | Description |
---|---|
np.ndarray | Grayscale image as a 2D numpy array. The output data type matches the input data type. |
Image types: uint8, float32
Number of channels: any
Source code in albumentations/augmentations/functional.py
def to_gray_max(img: np.ndarray) -> np.ndarray:
"""Convert an image to grayscale using the maximum channel value method.
This function takes the maximum value across all channels for each pixel,
resulting in a grayscale image that preserves the brightest parts of the original image.
Key aspects of this method:
1. Works with any number of channels, making it versatile for various image types.
2. For 3-channel (e.g., RGB) images, this method is equivalent to extracting the V (Value)
channel from the HSV color space.
3. Preserves the brightest parts of the image but may lose some color contrast information.
4. Simple and fast to compute.
Note:
- This method tends to produce brighter grayscale images compared to other conversion methods,
as it always selects the highest intensity value from the channels.
- For RGB images, it may not accurately represent perceived brightness as it doesn't
account for human color perception.
Args:
img (np.ndarray): Input image as a numpy array. Can be any number of channels.
Returns:
np.ndarray: Grayscale image as a 2D numpy array. The output data type
matches the input data type.
Image types:
uint8, float32
Number of channels:
any
"""
return np.max(img, axis=-1)
def to_gray_pca (img)
[view source on GitHub]¶
Convert an image to grayscale using Principal Component Analysis (PCA).
This function applies PCA to reduce a multi-channel image to a single channel, effectively creating a grayscale representation that captures the maximum variance in the color data.
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input image as a numpy array with shape (height, width, channels). |
Returns:
Type | Description |
---|---|
np.ndarray | Grayscale image as a 2D numpy array with shape (height, width). If input is uint8, output is uint8 in range [0, 255]. If input is float32, output is float32 in range [0, 1]. |
Note
This method can potentially preserve more information from the original image compared to standard weighted average methods, as it accounts for the correlations between color channels.
Image types: uint8, float32
Number of channels: any
Source code in albumentations/augmentations/functional.py
@clipped
def to_gray_pca(img: np.ndarray) -> np.ndarray:
"""Convert an image to grayscale using Principal Component Analysis (PCA).
This function applies PCA to reduce a multi-channel image to a single channel,
effectively creating a grayscale representation that captures the maximum variance
in the color data.
Args:
img (np.ndarray): Input image as a numpy array with shape (height, width, channels).
Returns:
np.ndarray: Grayscale image as a 2D numpy array with shape (height, width).
If input is uint8, output is uint8 in range [0, 255].
If input is float32, output is float32 in range [0, 1].
Note:
This method can potentially preserve more information from the original image
compared to standard weighted average methods, as it accounts for the
correlations between color channels.
Image types:
uint8, float32
Number of channels:
any
"""
dtype = img.dtype
# Reshape the image to a 2D array of pixels
pixels = img.reshape(-1, img.shape[2])
# Perform PCA
pca = PCA(n_components=1)
pca_result = pca.fit_transform(pixels)
# Reshape back to image dimensions and scale to 0-255
grayscale = pca_result.reshape(img.shape[:2])
grayscale = normalize_per_image(grayscale, "min_max")
return from_float(grayscale, target_dtype=dtype) if dtype == np.uint8 else grayscale
def to_gray_weighted_average (img)
[view source on GitHub]¶
Convert an RGB image to grayscale using the weighted average method.
This function uses OpenCV's cvtColor function with COLOR_RGB2GRAY conversion, which applies the following formula: Y = 0.299*R + 0.587*G + 0.114*B
Parameters:
Name | Type | Description |
---|---|---|
img | np.ndarray | Input RGB image as a numpy array. |
Returns:
Type | Description |
---|---|
np.ndarray | Grayscale image as a 2D numpy array. |
Image types: uint8, float32
Number of channels: 3
Source code in albumentations/augmentations/functional.py
def to_gray_weighted_average(img: np.ndarray) -> np.ndarray:
"""Convert an RGB image to grayscale using the weighted average method.
This function uses OpenCV's cvtColor function with COLOR_RGB2GRAY conversion,
which applies the following formula:
Y = 0.299*R + 0.587*G + 0.114*B
Args:
img (np.ndarray): Input RGB image as a numpy array.
Returns:
np.ndarray: Grayscale image as a 2D numpy array.
Image types:
uint8, float32
Number of channels:
3
"""
return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
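The strategies can be compared side by side; each returns a 2-D array of the same height and width (a sketch; to_gray_pca is omitted here only because it requires a PCA backend such as scikit-learn):
>>> import numpy as np
>>> from albumentations.augmentations.functional import (
...     to_gray_average, to_gray_desaturation, to_gray_from_lab,
...     to_gray_max, to_gray_weighted_average,
... )
>>> rng = np.random.default_rng(0)
>>> image = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)
>>> for fn in (to_gray_average, to_gray_desaturation, to_gray_from_lab,
...            to_gray_max, to_gray_weighted_average):
...     assert fn(image).shape == (64, 64)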