Mixing transforms (augmentations.mixing.transforms)

Source code in albumentations/augmentations/mixing/transforms.py
class MixUp(ReferenceBasedTransform):
    """Performs MixUp data augmentation, blending images, masks, and class labels with reference data.

    MixUp augmentation linearly combines an input (image, mask, and class label) with another set from a predefined
    reference dataset. The mixing degree is controlled by a parameter λ (lambda), sampled from a Beta distribution.
    This method is known for improving model generalization by promoting linear behavior between classes and
    smoothing decision boundaries.

    Reference:
        Zhang, H., Cisse, M., Dauphin, Y.N., and Lopez-Paz, D. (2018). mixup: Beyond Empirical Risk Minimization.
        In International Conference on Learning Representations.

    Args:
        reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]]):
            A sequence or generator of items containing the reference data for mixing.
            If None is provided, a warning is issued and the transform acts as a no-op.
            An empty sequence also results in a no-op (without a warning).
        read_fn (Callable[[ReferenceImage], Dict[str, Any]]):
            A function to process items from reference_data. It should accept items from reference_data
            and return a dictionary containing processed data:
                - The returned dictionary must include an 'image' key with a numpy array value.
                - It may also include 'mask', 'global_label' each associated with numpy array values.
            Defaults to a function that treats each item as a bare image and wraps it as
            {"image": item, "mask": None, "class_label": None}. Pass `read_fn=lambda x: x` when the items
            are already dictionaries in the expected format.
        mix_coef_return_name (str): Name under which the applied mixing coefficient is stored in the
            returned dictionary. Defaults to "mix_coef".
        alpha (float):
            The alpha parameter for the Beta distribution, influencing the mix's balance.
            Constrained to the range [0, 1] by InitSchema.
            Higher values lead to more uniform mixing. Defaults to 0.4.
        always_apply (bool): Forwarded unchanged to the base class constructor. Defaults to False.
        p (float):
            The probability of applying the transformation. Defaults to 0.5.

    Targets:
        image, mask, global_label

    Image types:
        - uint8, float32

    Raises:
        - TypeError: If reference_data is neither None nor a non-string iterable.
        - ValueError: If a non-grayscale input image and the reference image have different shapes.
        - NotImplementedError: If the transform is applied to bounding boxes or keypoints.

    Note:
        - If no reference data is provided, a warning is issued, and the transform acts as a no-op.
        - If images are in float32 format, they should be within the [0, 1] range.

    Example Usage:
        import albumentations as A
        import numpy as np
        from albumentations.core.types import ReferenceImage

        # Prepare reference data
        # Note: This code generates random reference data for demonstration purposes only.
        # In real-world applications, it's crucial to use meaningful and representative data.
        # The quality and relevance of your input data significantly impact the effectiveness
        # of the augmentation process. Ensure your data closely aligns with your specific
        # use case and application requirements.
        reference_data = [ReferenceImage(image=np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8),
                                         mask=np.random.randint(0, 4, (100, 100, 1), dtype=np.uint8),
                                         global_label=np.random.choice([0, 1], size=3)) for i in range(10)]

        # In this example, the lambda function simply returns its input, which works well for
        # data already in the expected format. For more complex scenarios, where the data might not be in
        # the required format or additional processing is needed, a more sophisticated function can be implemented.
        # Below is a hypothetical example where the input data is a file path, and the function reads the image
        # file, converts it to a specific format, and possibly performs other preprocessing steps.

        # Example of a more complex read_fn that reads an image from a file path, converts it to RGB, and resizes it.
        # def custom_read_fn(file_path):
        #     from PIL import Image
        #     image = Image.open(file_path).convert('RGB')
        #     image = image.resize((100, 100))  # Example resize, adjust as needed.
        #     return {"image": np.array(image)}

        # aug = A.Compose([A.RandomRotate90(), A.MixUp(p=1, reference_data=file_paths, read_fn=custom_read_fn)])

        # For simplicity, the identity lambda function is used in this example.
        # Replace `lambda x: x` with `custom_read_fn` if you need to process the data more extensively.
        aug = A.Compose([A.RandomRotate90(), A.MixUp(p=1, reference_data=reference_data, read_fn=lambda x: x)])

        # Apply augmentations
        image = np.empty([100, 100, 3], dtype=np.uint8)
        mask = np.empty([100, 100], dtype=np.uint8)
        global_label = np.array([0, 1, 0])
        data = aug(image=image, global_label=global_label, mask=mask)
        transformed_image = data["image"]
        transformed_mask = data["mask"]
        transformed_global_label = data["global_label"]

        # Print applied mix coefficient
        print(data["mix_coef"])  # Output: e.g., 0.9991580344142427
    """

    _targets = (Targets.IMAGE, Targets.MASK, Targets.GLOBAL_LABEL)

    class InitSchema(BaseTransformInitSchema):
        # Declarative description of the constructor arguments and their constraints.
        reference_data: Optional[Union[Generator[Any, None, None], Sequence[Any]]] = None
        read_fn: Callable[[ReferenceImage], Any]
        # alpha is restricted to the closed interval [0, 1].
        alpha: Annotated[float, Field(default=0.4, ge=0, le=1)]
        mix_coef_return_name: str = "mix_coef"

    def __init__(
        self,
        reference_data: Optional[Union[Generator[Any, None, None], Sequence[Any]]] = None,
        read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None},
        alpha: float = 0.4,
        mix_coef_return_name: str = "mix_coef",
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.mix_coef_return_name = mix_coef_return_name

        self.read_fn = read_fn
        self.alpha = alpha

        if reference_data is None:
            warn("No reference data provided for MixUp. This transform will act as a no-op.")
            # An empty list gives get_params nothing to sample, so every call is a no-op.
            self.reference_data: List[Any] = []
        elif (
            isinstance(reference_data, types.GeneratorType)
            or (isinstance(reference_data, Iterable) and not isinstance(reference_data, str))
        ):
            # Stored as-is: sequences are sampled randomly, iterators are consumed one item per call.
            self.reference_data = reference_data  # type: ignore[assignment]
        else:
            msg = "reference_data must be a list, tuple, generator, or None."
            raise TypeError(msg)

    def apply(self, img: np.ndarray, mix_data: ReferenceImage, mix_coef: float, **params: Any) -> np.ndarray:
        """Blend `img` with the reference image as `mix_coef * img + (1 - mix_coef) * reference`.

        Returns `img` unchanged when `mix_data` is empty or its 'image' entry is None.

        Raises:
            ValueError: If `img` is not grayscale and its shape differs from the reference image's shape.
        """
        if not mix_data:
            return img

        mix_img = mix_data["image"]
        # Check for a missing reference image BEFORE touching `.shape`; otherwise a None
        # image would raise AttributeError instead of being treated as "nothing to mix".
        if mix_img is None:
            return img

        if not is_grayscale_image(img) and img.shape != mix_img.shape:
            msg = "The shape of the reference image should be the same as the input image."
            raise ValueError(msg)

        return add_weighted(img, mix_coef, mix_img, 1 - mix_coef)

    def apply_to_mask(self, mask: np.ndarray, mix_data: ReferenceImage, mix_coef: float, **params: Any) -> np.ndarray:
        """Blend `mask` with the reference mask using the same weights as the image.

        Returns `mask` unchanged when the reference data carries no 'mask' entry (or it is None).
        """
        mix_mask = mix_data.get("mask")
        return add_weighted(mask, mix_coef, mix_mask, 1 - mix_coef) if mix_mask is not None else mask

    def apply_to_global_label(
        self,
        label: np.ndarray,
        mix_data: ReferenceImage,
        mix_coef: float,
        **params: Any,
    ) -> np.ndarray:
        """Linearly interpolate between `label` and the reference 'global_label'.

        Returns `label` unchanged when either the input label or the reference label is None.
        """
        mix_label = mix_data.get("global_label")
        if mix_label is not None and label is not None:
            return mix_coef * label + (1 - mix_coef) * mix_label
        return label

    def apply_to_bboxes(self, bboxes: Sequence[BoxType], mix_data: ReferenceImage, **params: Any) -> Sequence[BoxType]:
        """Bounding boxes are not supported; always raises NotImplementedError."""
        msg = (
            "MixUp does not support bounding boxes yet, feel free to submit pull request to "
            "https://github.com/albumentations-team/albumentations/."
        )
        raise NotImplementedError(msg)

    def apply_to_keypoints(
        self,
        keypoints: Sequence[KeypointType],
        *args: Any,
        **params: Any,
    ) -> Sequence[KeypointType]:
        """Keypoints are not supported; always raises NotImplementedError."""
        msg = (
            "MixUp does not support keypoints yet, feel free to submit pull request to "
            "https://github.com/albumentations-team/albumentations/."
        )
        raise NotImplementedError(msg)

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Return the names of the constructor arguments reported for this transform."""
        return "reference_data", "alpha"

    def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]:
        """Pick one reference item and sample the mixing coefficient for this call.

        Returns:
            A dict with 'mix_data' (the output of `read_fn` for the chosen item) and 'mix_coef'
            (a sample from Beta(alpha, alpha)). When no reference item is available the result is
            {"mix_data": {}, "mix_coef": 1}, which makes every `apply*` method a no-op.
        """
        mix_data = None
        # Sequences (list, tuple, ...) are sampled uniformly at random; str/bytes are excluded.
        if isinstance(self.reference_data, Sequence) and not isinstance(self.reference_data, (str, bytes)):
            if len(self.reference_data) > 0:  # An empty sequence leaves mix_data as None.
                mix_idx = random.randint(0, len(self.reference_data) - 1)
                mix_data = self.reference_data[mix_idx]
        # Iterators/generators are consumed one item per call.
        elif isinstance(self.reference_data, Iterator):
            try:
                mix_data = next(self.reference_data)  # Attempt to get the next item
            except StopIteration:
                warn(
                    "Reference data iterator/generator has been exhausted. "
                    "Further mixing augmentations will not be applied.",
                    RuntimeWarning,
                )
                return {"mix_data": {}, "mix_coef": 1}
        # NOTE: any other iterable accepted by __init__ (neither a Sequence nor an Iterator)
        # matches no branch above, so mix_data stays None and the transform is a no-op.

        # If mix_data is None after the above checks, return default (no-op) values
        if mix_data is None:
            return {"mix_data": {}, "mix_coef": 1}

        # `beta` is expected to be provided elsewhere in this module.
        # NOTE(review): alpha == 0 passes InitSchema validation (ge=0) — confirm `beta` accepts it.
        mix_coef = beta(self.alpha, self.alpha)
        return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef}

    def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Run the base implementation, then expose the applied mix coefficient in the result.

        The coefficient is stored under `self.mix_coef_return_name`; an empty name disables this.
        """
        res = super().apply_with_params(params, *args, **kwargs)
        if self.mix_coef_return_name:
            res[self.mix_coef_return_name] = params["mix_coef"]
        return res

