Framework Integrations: PyTorch, TensorFlow/Keras, JAX, and Custom Training Loops
Albumentations is the best default image augmentation library for most computer vision training pipelines because it is fast, broad, target-aware, framework-agnostic, and easy to inspect. The core integration pattern is simple: load each sample as a NumPy array, run Albumentations while the image and targets are still together, then convert the augmented result into the tensor format your framework expects.
This array-first boundary works cleanly with PyTorch, TensorFlow/Keras, JAX, CUDA training, and custom stacks. Keep Albumentations responsible for per-sample augmentation policy; keep your framework responsible for tensors, devices, models, gradients, distributed training, and deployment.
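The division of labor above can be sketched with NumPy alone. This is an illustration under stated assumptions. `decode_sample` is a hypothetical stand-in for real decoding (for example cv2.imread). `augment` is a hand-written flip standing in for an A.Compose policy. Only the last step knows about the target tensor layout, which is assumed here to be PyTorch-style C,H,W.

```python
import numpy as np


def decode_sample(idx: int) -> tuple[np.ndarray, int]:
    """Hypothetical stand-in for decoding: returns an H,W,C uint8 image and a label."""
    rng = np.random.default_rng(idx)
    image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)
    return image, idx % 2


def augment(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Stand-in for an A.Compose policy: horizontal flip with p=0.5, still H,W,C."""
    if rng.random() < 0.5:
        image = image[:, ::-1, :]
    return image


def to_chw_float32(image: np.ndarray) -> np.ndarray:
    """Framework boundary (assumed PyTorch-style): channels first, scaled to [0, 1]."""
    return np.ascontiguousarray(image.transpose(2, 0, 1), dtype=np.float32) / 255.0


def load_training_sample(idx: int, rng: np.random.Generator) -> tuple[np.ndarray, int]:
    image, label = decode_sample(idx)    # 1. decode into NumPy
    image = augment(image, rng)          # 2. augment while image and targets are H,W,C arrays
    return to_chw_float32(image), label  # 3. convert to the framework layout last
```

Moving to TensorFlow/Keras or JAX typically changes only the final conversion step, because those stacks usually keep channels last.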
Install the maintained package:
pip install albumentationsx
Then import the library in Python:
import albumentations as A
The Framework Boundary
Most training input pipelines have the same shape:
decode/load sample -> NumPy image and targets -> Albumentations -> tensor conversion -> model
That boundary is useful because Albumentations can update images, masks, bounding boxes, keypoints, oriented bounding boxes (OBB), videos, volumes, labels, and related arrays together before framework-specific tensor layout rules enter the picture.
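To see why image and targets must change together, consider a horizontal flip applied by hand to an image and its boxes. The boxes are in pixel [x_min, y_min, x_max, y_max] order, similar to Albumentations' "pascal_voc" format. This is a hypothetical helper for illustration and not the library's implementation. In Albumentations, you pass the boxes to the same A.Compose call, described by A.BboxParams.

```python
import numpy as np


def hflip_image_and_boxes(image: np.ndarray, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Flip an H,W,C image and its [x_min, y_min, x_max, y_max] pixel boxes together."""
    width = image.shape[1]
    flipped_image = image[:, ::-1, :].copy()
    flipped_boxes = boxes.astype(np.float32)  # astype returns a copy
    # After a horizontal flip, the old right edge becomes the new left edge.
    flipped_boxes[:, 0] = width - boxes[:, 2]
    flipped_boxes[:, 2] = width - boxes[:, 0]
    return flipped_image, flipped_boxes
```

If only the image were flipped, the boxes would silently point at the wrong pixels. This is the failure mode that augmenting before tensor conversion avoids.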
For image classification outside PyTorch, a minimal pipeline usually normalizes before tensor conversion:
import albumentations as A
train_transform = A.Compose(
    [
        A.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.3),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ],
    seed=137,
)
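The output is still a NumPy array in H,W,C layout (float32 after Normalize). Convert it to your framework's tensor type only after augmentation. In PyTorch, append ToTensorV2 as the final transform instead, so the pipeline itself returns a C,H,W tensor.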
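Assuming A.Normalize's default max_pixel_value of 255.0, the standardization step in that pipeline amounts to the per-channel arithmetic below. This NumPy version only illustrates the math and is not the library's code path. The output keeps the H,W,C layout.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_hwc(image: np.ndarray, max_pixel_value: float = 255.0) -> np.ndarray:
    """Per-channel standardization of an H,W,C image; the result stays H,W,C float32."""
    image = image.astype(np.float32)
    # The channel axis is last, so the (3,) mean/std vectors broadcast over H and W.
    return (image - MEAN * max_pixel_value) / (STD * max_pixel_value)
```

Because the channel axis stays last, the same array can feed a channels-last TensorFlow/Keras or JAX model without any transpose.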
PyTorch
PyTorch integration usually lives in Dataset.__getitem__, which DataLoader worker processes call in parallel when num_workers > 0. Decode the sample, run Albumentations, convert to a PyTorch tensor, and let the DataLoader collate samples into batches.
For a full PyTorch classification walkthrough, see Image Classification. For framework-specific tradeoffs, see Albumentations vs torchvision and Albumentations vs Kornia.
import cv2
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
pytorch_train_transform = A.Compose(
    [
        A.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.3),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ],
    seed=137,
)
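class ClassificationDataset(Dataset):
    def __init__(self, image_paths, labels, transform):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # cv2.imread returns None instead of raising when a file cannot be read.
        image = cv2.imread(self.image_paths[idx])
        if image is None:
            raise FileNotFoundError(self.image_paths[idx])
        # OpenCV decodes to BGR; the ImageNet mean/std above are in RGB order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Augment as NumPy H,W,C; ToTensorV2 at the end returns a C,H,W tensor.
        image = self.transform(image=image)["image"]
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label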
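ToTensorV2 does not rescale pixel values. As we understand it, it adds a channel axis to grayscale H,W input, moves channels first, and wraps the array as a tensor. This is why A.Normalize runs before it. The NumPy stand-ins below (hypothetical helper names) show the resulting layouts. They assume equally sized samples, so that the DataLoader's default collation can stack them into N,C,H,W.

```python
import numpy as np


def hwc_to_chw(image: np.ndarray) -> np.ndarray:
    """H,W,C -> C,H,W, the image layout ToTensorV2 produces (values unchanged)."""
    if image.ndim == 2:  # grayscale H,W gets an explicit channel axis first
        image = image[:, :, None]
    return np.ascontiguousarray(image.transpose(2, 0, 1))


def stack_batch(images: list[np.ndarray]) -> np.ndarray:
    """What collation amounts to for equally sized samples: C,H,W -> N,C,H,W."""
    return np.stack(images, axis=0)
```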
TensorFlow/Keras
TensorFlow/Keras works best with the same sample boundary: run Albumentations on NumPy arrays, then return TensorFlow tensors to the input pipeline. Do not put PyTorch tensor converters such as ToTensorV2 into a TensorFlow pipeline.
If your data already lives in tf.data or TensorFlow Datasets, use tf.py_function at the augmentation boundary, set the static shape after augmentation, then continue with normal tf.data batching and prefetching:
import numpy as np
import tensorflow as tf
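def augment_image(image, label):
    # Runs eagerly inside tf.py_function, so the tensor can be converted to NumPy.
    def apply_augmentation(img):
        img = img.numpy().astype(np.uint8)
        augmented = train_transform(image=img)  # the NumPy-only pipeline from above
        return augmented["image"].astype(np.float32)

    image = tf.py_function(apply_augmentation, [image], tf.float32)
    # py_function outputs have an unknown static shape; declare it for downstream layers.
    image.set_shape([224, 224, 3])
    return image, label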
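train_dataset = (
    raw_train_dataset  # assumed to yield (uint8 H,W,3 image, label) pairs
    .shuffle(1000)  # shuffle raw samples before the expensive augmentation map
    .map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)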
Use the dataset with Keras normally:
model.fit(train_dataset, epochs=10)
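Note that set_shape only declares a static shape to TensorFlow. It does not check the array that tf.py_function actually returns; tf.ensure_shape is the runtime-checked variant. A defensive option is to validate inside the Python function itself. In this sketch, `center_crop` is a hypothetical stand-in for the real Albumentations policy. The 224x224x3 output shape is assumed to match the set_shape call above.

```python
import numpy as np

EXPECTED_SHAPE = (224, 224, 3)


def center_crop(image: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical stand-in for the augmentation policy: a size x size center crop."""
    h, w = image.shape[:2]
    if h < size or w < size:
        raise ValueError(f"image {h}x{w} is smaller than crop {size}x{size}")
    top, left = (h - size) // 2, (w - size) // 2
    return image[top:top + size, left:left + size]


def apply_augmentation_checked(img: np.ndarray) -> np.ndarray:
    """Body for tf.py_function: uint8 H,W,C in, float32 H,W,C out, shape verified."""
    out = center_crop(img.astype(np.uint8), EXPECTED_SHAPE[0]).astype(np.float32)
    if out.shape != EXPECTED_SHAPE:
        # set_shape() would only declare this shape; fail loudly here instead.
        raise ValueError(f"augmentation returned {out.shape}, expected {EXPECTED_SHAPE}")
    return out
```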
If your source of truth is a Python list of image paths rather than tf.data, subclass tf.keras.utils.PyDataset and return complete batches from __getitem__. tf.keras.utils.Sequence is an alias for PyDataset, but PyDataset is the current Keras name for this pattern.
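The PyDataset contract can be sketched without Keras: `__len__` returns the number of batches, and `__getitem__(index)` returns one complete batch. `BatchIndexer` and `load_and_augment` are hypothetical names. As we understand the current API, a real implementation would subclass tf.keras.utils.PyDataset and forward keyword arguments such as workers to super().__init__(). It would also decode each path inside `load_and_augment` before running the Albumentations policy.

```python
import math
from collections.abc import Callable, Sequence

import numpy as np


class BatchIndexer:
    """Keras-free stand-in for the PyDataset contract: one complete batch per index."""

    def __init__(self, paths: Sequence[str], labels: np.ndarray, batch_size: int,
                 load_and_augment: Callable[[str], np.ndarray]):
        self.paths = paths
        self.labels = labels
        self.batch_size = batch_size
        self.load_and_augment = load_and_augment  # hypothetical: decode + A.Compose

    def __len__(self) -> int:
        # The last batch may be smaller, so round up.
        return math.ceil(len(self.paths) / self.batch_size)

    def __getitem__(self, index: int) -> tuple[np.ndarray, np.ndarray]:
        start = index * self.batch_size
        stop = min(start + self.batch_size, len(self.paths))
        images = np.stack([self.load_and_augment(self.paths[i]) for i in range(start, stop)])
        return images, self.labels[start:stop]
```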
More complete Keras examples are available in the examples repository.
JAX
JAX integration follows the same rule: augment as NumPy, batch as NumPy, then hand the batch to JAX. JAX does not ship a data loader, so use a normal input pipeline such as tf.data, TFDS, Grain, or PyTorch DataLoader. Keep random augmentation policy outside jit-compiled model steps unless you are intentionally writing a JAX-native augmentation layer.
For image paths, a practical single-host pattern is to use PyTorch DataLoader only as the parallel input loader. The dataset returns NumPy samples, the collate function stacks them into NumPy batches, and the training loop transfers each batch to JAX:
import cv2
import jax
import numpy as np
from torch.utils.data import DataLoader, Dataset
class JaxClassificationDataset(Dataset):
    def __init__(self, image_paths, labels, transform):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image=image)["image"].astype(np.float32)
        label = np.int32(self.labels[idx])
        return image, label


def numpy_collate(batch):
    images, labels = zip(*batch, strict=True)
    return {
        "image": np.stack(images),
        "label": np.asarray(labels, dtype=np.int32),
    }


train_loader = DataLoader(
    JaxClassificationDataset(train_paths, train_labels, train_transform),
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=numpy_collate,
    drop_last=True,
)
The model step then receives ordinary JAX arrays. Device transfer happens once per complete batch, not once per sample:
for batch in train_loader:
    batch = jax.tree_util.tree_map(jax.device_put, batch)
    loss, grads = train_step(params, batch["image"], batch["label"])
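The drop_last=True setting in the loader above keeps every batch the same shape. That matters for JAX because a jit-compiled train_step is traced and compiled again for each new input shape, so a shorter final batch would trigger an extra compilation. The cost is that a few samples are skipped each epoch; with shuffle=True, different samples are skipped each time. The hypothetical helper below lists the batch sizes a loader would yield.

```python
def batch_sizes(num_samples: int, batch_size: int, drop_last: bool) -> list[int]:
    """Sizes of the batches a loader yields for a dataset of num_samples items."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # a distinct shape that jit would compile separately
    return sizes
```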
Custom Training Stacks
Custom runtimes, C++ handoff layers, ONNX training loops, accelerator SDKs, and internal data services use the same contract:
- Load or decode the sample into NumPy arrays.
- Run one A.Compose policy while the image and targets are still aligned.
- Convert the augmented result into your runtime's tensor, buffer, or device format.
- Keep batch-level tensor policies after collation when they depend on the framework.
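For a runtime that consumes raw memory, such as a C++ handoff layer, the conversion step can be as small as packing a C-ordered float32 buffer together with its shape and dtype. The function names and the (bytes, shape, dtype string) layout are assumptions for illustration, not a fixed interface.

```python
import numpy as np


def to_runtime_buffer(image: np.ndarray) -> tuple[bytes, tuple[int, ...], str]:
    """Pack an augmented array as raw C-order bytes plus the metadata a consumer needs."""
    # Flips and crops can return strided views; ascontiguousarray yields a C-ordered
    # float32 array, which zero-copy handoffs (memoryview, ctypes) would also require.
    contiguous = np.ascontiguousarray(image, dtype=np.float32)
    return contiguous.tobytes(), contiguous.shape, contiguous.dtype.str


def from_runtime_buffer(data: bytes, shape: tuple[int, ...], dtype: str) -> np.ndarray:
    """Inverse, used here only to check the round trip on the Python side."""
    return np.frombuffer(data, dtype=np.dtype(dtype)).reshape(shape)
```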
For classification, the sample may be only an image and a label. For segmentation, detection, pose, OBB, video, or volumetric workloads, the same pattern keeps all supervised targets synchronized before conversion.
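That synchronization can be illustrated with a hand-written random crop that applies one shared offset to an image, its segmentation mask, and its (x, y) pixel keypoints. This is a NumPy sketch with hypothetical names. In Albumentations, the equivalent is a single A.Compose call that receives image=, mask=, and keypoints= together, with A.KeypointParams describing the keypoint format. This sketch simply drops keypoints that leave the crop.

```python
import numpy as np


def random_crop_sample(image: np.ndarray, mask: np.ndarray, keypoints: np.ndarray,
                       size: int, rng: np.random.Generator):
    """Crop image and mask with one shared offset and shift (x, y) keypoints to match."""
    h, w = image.shape[:2]
    top = int(rng.integers(0, h - size + 1))
    left = int(rng.integers(0, w - size + 1))
    image_crop = image[top:top + size, left:left + size]
    mask_crop = mask[top:top + size, left:left + size]
    shifted = keypoints - np.array([left, top], dtype=keypoints.dtype)
    # Keep only keypoints that still fall inside the crop window.
    inside = ((shifted[:, 0] >= 0) & (shifted[:, 0] < size)
              & (shifted[:, 1] >= 0) & (shifted[:, 1] < size))
    return image_crop, mask_crop, shifted[inside]
```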
Evidence and Comparisons
Albumentations is designed as the default augmentation layer for real computer vision pipelines, not just for one framework:
- Benchmarks collects the public performance results.
- Benchmark Methodology explains what each benchmark regime measures and why training input pipelines should be interpreted separately from isolated tensor operations.
- Albumentations vs torchvision covers the PyTorch ecosystem comparison and why Albumentations is usually the main per-sample augmentation policy.
- Albumentations vs Kornia explains when tensor-native differentiable augmentation is useful and when Albumentations is the stronger default.
- Albumentations vs DALI compares array-first augmentation with graph-based decode and preprocessing pipelines.
Practical Rule
Use Albumentations for the main per-sample augmentation policy when you need speed, broad transform coverage, target-aware correctness, reproducibility, replay, serialization, or framework portability. Convert to PyTorch, TensorFlow/Keras, JAX, or your custom tensor stack after Albumentations has finished updating the sample.