Skip to content

Transforms (pytorch.transforms)

class ToTensorV2 (transpose_mask=False, p=1.0, always_apply=None) [view source on GitHub] ΒΆ

Converts images/masks to PyTorch Tensors, inheriting from BasicTransform. For images: - If input is in HWC format, converts to PyTorch CHW format - If input is in HW format, converts to PyTorch 1HW format (adds channel dimension)

Attributes:

Name Type Description
transpose_mask bool

If True, transposes 3D input mask dimensions from [height, width, num_channels] to [num_channels, height, width].

p float

Probability of applying the transform. Default: 1.0.

Interactive Tool Available!

Explore this transform visually and adjust parameters interactively using this tool:

Open Tool

Source code in albumentations/pytorch/transforms.py
Python
class ToTensorV2(BasicTransform):
    """Converts images/masks to PyTorch Tensors, inheriting from BasicTransform.
    For images:
        - If input is in `HWC` format, converts to PyTorch `CHW` format
        - If input is in `HW` format, converts to PyTorch `1HW` format (adds channel dimension)

    Attributes:
        transpose_mask (bool): If True, transposes 3D input mask dimensions from `[height, width, num_channels]` to
            `[num_channels, height, width]`.
        p (float): Probability of applying the transform. Default: 1.0.
    """

    _targets = (Targets.IMAGE, Targets.MASK)

    def __init__(self, transpose_mask: bool = False, p: float = 1.0, always_apply: bool | None = None):
        super().__init__(p=p, always_apply=always_apply)
        self.transpose_mask = transpose_mask

    @property
    def targets(self) -> dict[str, Any]:
        return {
            "image": self.apply,
            "images": self.apply_to_images,
            "mask": self.apply_to_mask,
            "masks": self.apply_to_masks,
        }

    def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:
        if img.ndim not in {MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS}:
            msg = "Albumentations only supports images in HW or HWC format"
            raise ValueError(msg)

        if img.ndim == MONO_CHANNEL_DIMENSIONS:
            img = np.expand_dims(img, 2)

        return torch.from_numpy(img.transpose(2, 0, 1))

    def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor:
        if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
            mask = mask.transpose(2, 0, 1)
        return torch.from_numpy(mask)

    def apply_to_masks(self, masks: list[np.ndarray], **params: Any) -> list[torch.Tensor]:
        return [self.apply_to_mask(mask, **params) for mask in masks]

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        return ("transpose_mask",)