Transforms (pytorch.transforms)
class ToTensor3D
(p=1.0, always_apply=None)
[view source on GitHub]
Convert 3D volumes and masks to PyTorch tensors.
This transform is designed for 3D medical imaging data. It converts numpy arrays to PyTorch tensors and ensures consistent channel positioning.
For all inputs (volumes and masks):
- Input: (D, H, W, C) or (D, H, W) — depth, height, width, [channels]
- Output: (C, D, H, W) — channels-first format for PyTorch
- For single-channel (D, H, W) input, a C=1 dimension is added.
Note
This transform always moves channels to first position as this is the standard PyTorch format. For masks that need to stay in DHWC format, use a different transform or handle the transposition after this transform.
Parameters:
Name | Type | Description |
---|---|---|
p | float | Probability of applying the transform. Default: 1.0 |
Interactive Tool Available!
Explore this transform visually and adjust parameters interactively using this tool:
Source code in albumentations/pytorch/transforms.py
class ToTensor3D(BasicTransform):
    """Convert 3D volumes and masks to PyTorch tensors.

    This transform is designed for 3D medical imaging data. It converts numpy arrays
    to PyTorch tensors and ensures consistent channel positioning.

    For all inputs (volumes and masks):
        - Input: (D, H, W, C) or (D, H, W) - depth, height, width, [channels]
        - Output: (C, D, H, W) - channels first format for PyTorch
        - For single-channel (D, H, W) input, a leading C=1 dimension is added.

    Arrays with negative strides cannot be wrapped by ``torch.from_numpy``. Such
    arrays (e.g. views produced by ``np.flip``, ``np.rot90`` or ``[::-1]`` slicing)
    are first copied into a contiguous buffer. All other arrays are converted
    without an extra copy.

    Note:
        This transform always moves channels to first position as this is
        the standard PyTorch format. For masks that need to stay in DHWC format,
        use a different transform or handle the transposition after this transform.

    Args:
        p (float): Probability of applying the transform. Default: 1.0
        always_apply (bool | None): Passed through unchanged to ``BasicTransform``.

    """

    # NOTE(review): declares IMAGE/MASK, while the dispatch mapping returned by the
    # ``targets`` property below uses "volume"/"mask3d" -- confirm this is intended.
    _targets = (Targets.IMAGE, Targets.MASK)

    def __init__(self, p: float = 1.0, always_apply: bool | None = None):
        super().__init__(p=p, always_apply=always_apply)

    @property
    def targets(self) -> dict[str, Any]:
        """Map target names to the methods that convert them."""
        return {
            "volume": self.apply_to_volume,
            "mask3d": self.apply_to_mask3d,
        }

    @staticmethod
    def _from_numpy(array: np.ndarray) -> torch.Tensor:
        """Wrap ``array`` as a tensor, copying only when it has negative strides.

        ``torch.from_numpy`` raises on negatively-strided arrays, so those are made
        C-contiguous first. Every other array is handed to ``torch.from_numpy`` as-is.
        """
        if any(stride < 0 for stride in array.strides):
            array = np.ascontiguousarray(array)
        return torch.from_numpy(array)

    def apply_to_volume(self, volume: np.ndarray, **params: Any) -> torch.Tensor:
        """Convert a 3D volume to a channels-first tensor.

        Args:
            volume: Array of shape (D, H, W, C) or (D, H, W).
            **params: Unused here.

        Returns:
            Tensor of shape (C, D, H, W). A (D, H, W) input becomes (1, D, H, W).

        Raises:
            ValueError: If ``volume`` is neither 3D nor 4D.

        """
        if volume.ndim == NUM_VOLUME_DIMENSIONS:  # D,H,W,C -> C,D,H,W
            return self._from_numpy(volume.transpose(3, 0, 1, 2))
        if volume.ndim == NUM_VOLUME_DIMENSIONS - 1:  # D,H,W -> 1,D,H,W
            return self._from_numpy(volume[np.newaxis, ...])
        raise ValueError(f"Expected 3D or 4D array (D,H,W) or (D,H,W,C), got {volume.ndim}D array")

    def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> torch.Tensor:
        """Convert a 3D mask to a channels-first tensor (same rules as volumes)."""
        return self.apply_to_volume(mask3d, **params)

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Return the names of init arguments to serialize (none besides ``p``)."""
        return ()
class ToTensorV2
(transpose_mask=False, p=1.0, always_apply=None)
[view source on GitHub]
Converts images/masks to PyTorch Tensors, inheriting from BasicTransform. For images:
- If input is in `HWC` format, converts to PyTorch `CHW` format.
- If input is in `HW` format, converts to PyTorch `1HW` format (adds channel dimension).
Attributes:
Name | Type | Description |
---|---|---|
transpose_mask | bool | If True, transposes 3D input mask dimensions from `[height, width, num_channels]` to `[num_channels, height, width]`. |
p | float | Probability of applying the transform. Default: 1.0. |
Interactive Tool Available!
Explore this transform visually and adjust parameters interactively using this tool:
Source code in albumentations/pytorch/transforms.py
class ToTensorV2(BasicTransform):
    """Converts images/masks to PyTorch Tensors, inheriting from BasicTransform.

    For images:
        - If input is in `HWC` format, converts to PyTorch `CHW` format
        - If input is in `HW` format, converts to PyTorch `1HW` format (adds channel dimension)

    For image batches (the ``images`` target):
        - `NHWC` input is converted to `NCHW`
        - `NHW` input is converted to `N1HW` (adds channel dimension, mirroring the
          single-image `HW` case)

    Masks keep their layout unless ``transpose_mask`` is True.

    Arrays with negative strides cannot be wrapped by ``torch.from_numpy``. Such
    arrays (e.g. views produced by ``np.flip`` or ``[::-1]`` slicing) are first
    copied into a contiguous buffer. All other arrays are converted without an
    extra copy.

    Attributes:
        transpose_mask (bool): If True, transposes 3D input mask dimensions from
            `[height, width, num_channels]` to `[num_channels, height, width]`
            (and 4D mask batches from `[N, H, W, C]` to `[N, C, H, W]`).
        p (float): Probability of applying the transform. Default: 1.0.

    """

    _targets = (Targets.IMAGE, Targets.MASK)

    def __init__(self, transpose_mask: bool = False, p: float = 1.0, always_apply: bool | None = None):
        super().__init__(p=p, always_apply=always_apply)
        self.transpose_mask = transpose_mask

    @property
    def targets(self) -> dict[str, Any]:
        """Map target names to the methods that convert them."""
        return {
            "image": self.apply,
            "images": self.apply_to_images,
            "mask": self.apply_to_mask,
            "masks": self.apply_to_masks,
        }

    @staticmethod
    def _from_numpy(array: np.ndarray) -> torch.Tensor:
        """Wrap ``array`` as a tensor, copying only when it has negative strides.

        ``torch.from_numpy`` raises on negatively-strided arrays, so those are made
        C-contiguous first. Every other array is handed to ``torch.from_numpy`` as-is.
        """
        if any(stride < 0 for stride in array.strides):
            array = np.ascontiguousarray(array)
        return torch.from_numpy(array)

    def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:
        """Convert a single image to a channels-first tensor.

        Args:
            img: Array in HW or HWC layout.
            **params: Unused here.

        Returns:
            Tensor of shape (C, H, W). HW input becomes (1, H, W).

        Raises:
            ValueError: If ``img`` is neither 2D nor 3D.

        """
        if img.ndim not in {MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS}:
            msg = "Albumentations only supports images in HW or HWC format"
            raise ValueError(msg)
        if img.ndim == MONO_CHANNEL_DIMENSIONS:
            img = np.expand_dims(img, 2)  # HW -> HW1 so the transpose below is uniform
        return self._from_numpy(img.transpose(2, 0, 1))

    def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor:
        """Convert a single mask to a tensor.

        HWC masks are transposed to CHW only when ``transpose_mask`` is True.
        Masks of any other rank are converted with their shape unchanged.
        """
        if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
            mask = mask.transpose(2, 0, 1)
        return self._from_numpy(mask)

    @overload
    def apply_to_masks(self, masks: list[np.ndarray], **params: Any) -> list[torch.Tensor]: ...

    @overload
    def apply_to_masks(self, masks: np.ndarray, **params: Any) -> torch.Tensor: ...

    def apply_to_masks(self, masks: np.ndarray | list[np.ndarray], **params: Any) -> torch.Tensor | list[torch.Tensor]:
        """Convert numpy array or list of numpy array masks to torch tensor(s).

        Args:
            masks: Numpy array of shape (N, H, W) or (N, H, W, C),
                or a list of numpy arrays with shape (H, W) or (H, W, C).
            params: Additional parameters, forwarded to ``apply_to_mask`` for list input.

        Returns:
            For list input, a list with one tensor per mask, each converted by
            ``apply_to_mask`` (so (H, W, C) masks become (C, H, W) when
            transpose_mask is True).
            For array input, a single tensor. If transpose_mask is True and the input
            is (N, H, W, C), its shape is (N, C, H, W). Otherwise it has the same
            shape as the input.

        """
        if isinstance(masks, list):
            return [self.apply_to_mask(mask, **params) for mask in masks]
        if self.transpose_mask and masks.ndim == NUM_VOLUME_DIMENSIONS:  # (N, H, W, C)
            masks = np.transpose(masks, (0, 3, 1, 2))  # -> (N, C, H, W)
        return self._from_numpy(masks)

    def apply_to_images(self, images: np.ndarray, **params: Any) -> torch.Tensor:
        """Convert a batch of images to a channels-first tensor.

        Args:
            images: Array of shape (N, H, W, C), or (N, H, W) for single-channel images.
            **params: Unused here.

        Returns:
            Tensor of shape (N, C, H, W). (N, H, W) input becomes (N, 1, H, W),
            mirroring how ``apply`` treats a single HW image.

        Raises:
            ValueError: If ``images`` is neither 3D nor 4D.

        """
        if images.ndim == NUM_VOLUME_DIMENSIONS - 1:  # N,H,W -> N,1,H,W
            return self._from_numpy(images[:, np.newaxis, ...])
        if images.ndim != NUM_VOLUME_DIMENSIONS:  # N,H,W,C
            raise ValueError(f"Expected 3D or 4D array (N,H,W) or (N,H,W,C), got {images.ndim}D array")
        return self._from_numpy(images.transpose(0, 3, 1, 2))  # -> (N,C,H,W)

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Return the names of init arguments to serialize besides ``p``."""
        return ("transpose_mask",)