
D4 Test-Time Augmentation for Aerial Semantic Segmentation

This notebook demonstrates Test-Time Augmentation (TTA) applied to semantic segmentation of top-down aerial imagery, and shows why the D4 dihedral symmetry group is the natural choice for this domain.

Dataset: restor/tcd, aerial imagery (10 cm/px) with binary tree-cover masks. The camera looks straight down, so there is no privileged orientation and the data has D4 symmetry. Model: restor/tcd-segformer-mit-b2, a SegFormer fine-tuned on aerial tree cover.

This notebook was last checked on Feb 28, 2026. If something does not work, please submit an issue at https://github.com/albumentations-team/AlbumentationsX/issues

What is Test-Time Augmentation?

During training, we apply random augmentations so the model learns features robust to various transformations. During inference, we usually just preprocess and run a single forward pass — one image in, one prediction out.

Test-Time Augmentation bridges that gap: at inference time, we create multiple augmented versions of the same input, pass each through the model, and aggregate the predictions. The idea is that if the model has learned to be approximately invariant to a transformation, averaging predictions across transformed views reduces variance and improves accuracy — with zero retraining.

                ┌─── Transform 1 ──→ Model ──→ Prediction 1 ───┐
                │                                               │
Input Image ────┼─── Transform 2 ──→ Model ──→ Prediction 2 ───┼──→ Average ──→ Final Prediction
                │                                               │
                └─── Transform 3 ──→ Model ──→ Prediction 3 ───┘

TTA for Segmentation: Equivariance

For classification, we want the network to be invariant — a cat is still a cat when flipped. For segmentation, we need equivariance: if we transform the input, the output mask should transform in the same way.

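Formally, a segmentation network f is equivariant to a geometric transform T with inverse T⁻¹ if f(T(x)) = T(f(x)) for every input x. Applying T⁻¹ to both sides gives the form used for TTA: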

T⁻¹( f( T(x) ) ) = f(x)
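A tiny NumPy experiment makes the definition concrete. This is a sketch under stated assumptions: `pixelwise_model` and `sky_prior_model` are hypothetical stand-ins for a network, and T is a 90° rotation implemented with `np.rot90`. A model whose output at each pixel depends only on that pixel satisfies the identity exactly. A model with a baked-in orientation prior (here, a hypothetical rule that the top row is never tree) breaks it, much like a street-scene model that expects the sky at the top.

```python
import numpy as np


def T(x):
    """The transform: rotate 90° counter-clockwise in the (H, W) plane."""
    return np.rot90(x, k=1, axes=(0, 1))


def T_inv(x):
    """Its inverse: rotate 90° clockwise."""
    return np.rot90(x, k=-1, axes=(0, 1))


def pixelwise_model(x):
    """Hypothetical model: each output pixel depends only on the same input pixel."""
    return (x.mean(axis=-1) > 0.5).astype(np.int64)


def sky_prior_model(x):
    """Hypothetical model with an orientation prior: the top row is never 'tree'."""
    out = pixelwise_model(x)
    out[0, :] = 0
    return out


def equivariance_holds(f, x):
    """Check T⁻¹(f(T(x))) == f(x) exactly."""
    return np.array_equal(T_inv(f(T(x))), f(x))


x = np.random.default_rng(0).random((6, 6, 3))
x[:, -1, :] = 1.0  # bright strip on the right edge; T moves it to the top row

print(equivariance_holds(pixelwise_model, x))  # True
print(equivariance_holds(sky_prior_model, x))  # False
```

Real segmentation networks sit between these extremes: they are only approximately equivariant, which is why averaging the re-aligned predictions can reduce variance instead of simply reproducing f(x).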

The equivariance identity gives the TTA recipe for segmentation:

  1. Apply transform T to the image
  2. Run the model → get logits in transformed space
  3. Apply the inverse transform T⁻¹ to the logits to bring them back to the original coordinate space
  4. Average the aligned logit maps → argmax → final mask
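The four steps fit in a few lines of NumPy. This is a minimal sketch, not the notebook's implementation: `tta_predict` and `toy_model` are hypothetical names, logits are kept in [H, W, C] layout, and each transform is passed as an explicit (T, T⁻¹) pair of functions acting on the first two axes.

```python
import numpy as np


def tta_predict(image, model, transform_pairs):
    """Steps 1-4 of the recipe for an [H, W, 3] image and a model returning [h, w, C] logits."""
    aligned = []
    for T, T_inv in transform_pairs:
        logits_t = model(T(image))       # steps 1-2: logits in the transformed space
        aligned.append(T_inv(logits_t))  # step 3: back to the original coordinates
    avg_logits = np.mean(aligned, axis=0)  # step 4: average the aligned logit maps...
    return avg_logits.argmax(axis=-1)      # ...then argmax -> [H, W] class map


def toy_model(x):
    """Hypothetical 2-class model: the 'tree' logit is pixel brightness minus 0.5."""
    brightness = x.mean(axis=-1)
    return np.stack([0.5 - brightness, brightness - 0.5], axis=-1)


def identity(a):
    return a


def hflip(a):  # a horizontal flip is its own inverse
    return a[:, ::-1]


def rot90(a):  # 90° counter-clockwise in the (H, W) plane
    return np.rot90(a, 1, axes=(0, 1))


def rot270(a):  # 90° clockwise: the inverse of rot90
    return np.rot90(a, -1, axes=(0, 1))


pairs = [(identity, identity), (hflip, hflip), (rot90, rot270)]
demo_image = np.random.default_rng(42).random((4, 5, 3))  # non-square on purpose
mask = tta_predict(demo_image, toy_model, pairs)
print(mask.shape)  # (4, 5)
```

The non-square input is deliberate: the 90° rotation swaps H and W, so the logit maps only stack for averaging because each one is rotated back first.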

AlbumentationsX 2.0.19 makes step 3 trivial: every symmetric spatial transform now has an .inverse() method.

Why D4 for Aerial Imagery?

For natural street-level photos, only horizontal flip makes sense — the world has a gravitational prior (sky is up, ground is down). Rotating a street scene by 90° produces an unnatural image the model has never seen.

Top-down aerial imagery has no privileged orientation. A building, road, or field patch is an equally plausible aerial view at 0°, 90°, 180°, 270°, and in each of their mirror reflections. This gives us the full D4 dihedral symmetry group, which has 8 elements:

Element   Transform
e         Identity (original)
r90       Rotate 90°
r180      Rotate 180°
r270      Rotate 270°
h         Horizontal flip
v         Vertical flip
t         Transpose (reflect across main diagonal)
hvt       Reflect across anti-diagonal
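For experimenting outside the library, the eight elements and their inverses can be written directly in NumPy. This is a hypothetical stand-in, not AlbumentationsX's implementation: rotations follow the counter-clockwise `np.rot90` convention and "h" mirrors left-right, which may not match the library's naming or rotation direction. The loop checks the inverse table: each rotation is undone by the opposite rotation, and every reflection is its own inverse.

```python
import numpy as np

# Hypothetical NumPy stand-ins for the 8 D4 elements, acting on axes (0, 1) of an
# [H, W, C] array. Rotations are counter-clockwise (the np.rot90 convention).
D4_FORWARD = {
    "e": lambda x: x,
    "r90": lambda x: np.rot90(x, 1, axes=(0, 1)),
    "r180": lambda x: np.rot90(x, 2, axes=(0, 1)),
    "r270": lambda x: np.rot90(x, 3, axes=(0, 1)),
    "h": lambda x: x[:, ::-1],                          # mirror left-right
    "v": lambda x: x[::-1],                             # mirror top-bottom
    "t": lambda x: np.swapaxes(x, 0, 1),                # reflect across main diagonal
    "hvt": lambda x: np.swapaxes(x[::-1, ::-1], 0, 1),  # reflect across anti-diagonal
}

# Rotations are undone by the opposite rotation; every reflection undoes itself.
D4_INVERSE = {"e": "e", "r90": "r270", "r180": "r180", "r270": "r90",
              "h": "h", "v": "v", "t": "t", "hvt": "hvt"}


def invert(name, arr):
    """Apply the inverse of D4 element `name` to `arr`."""
    return D4_FORWARD[D4_INVERSE[name]](arr)


logits = np.random.default_rng(0).standard_normal((4, 6, 2))  # non-square [H, W, C]
for name, forward in D4_FORWARD.items():
    assert np.array_equal(invert(name, forward(logits)), logits), name
print("all 8 inverses verified")
```

The random, non-square test array makes a wrong inverse fail loudly: it has no accidental symmetries, and any element that swaps H and W also changes the shape.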

All 8 transforms are semantically valid for aerial data — so we can use all of them for TTA.

We compare four modes in this notebook:

Mode                 Views   Transforms
Baseline             1       No TTA
HorizontalFlip TTA   2       identity + hflip
RandomRotate90 TTA   4       0°, 90°, 180°, 270°
D4 TTA               8       all 8 dihedral elements

For more details, see the AlbumentationsX TTA documentation.

Install Dependencies

!pip install "albumentationsx[headless]" datasets transformers huggingface_hub torch torchvision tqdm -q

Imports

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image
from datasets import load_dataset
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
import albumentations as A
from albumentations.core.type_definitions import d4_group_elements
from tqdm import tqdm

%matplotlib inline

Configuration

DEVICE = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)
BENCHMARK_IMAGES = 20  # number of images used for mIoU benchmark

# restor/tcd: binary (background, tree)
NUM_CLASSES = 2
CLASS_NAMES = ["background", "tree"]
CLASS_COLORS = [(80, 80, 80), (34, 139, 34)]

print(f"Device: {DEVICE}")
Device: mps

Load Pretrained Model from HuggingFace

We load restor/tcd-segformer-mit-b2 — SegFormer fine-tuned on aerial tree cover (binary: tree vs background). Top-view imagery has D4 symmetry.

processor = AutoImageProcessor.from_pretrained("restor/tcd-segformer-mit-b2")
model = SegformerForSemanticSegmentation.from_pretrained("restor/tcd-segformer-mit-b2")
model = model.to(DEVICE).eval()

print("Model loaded successfully.")
The image processor of type `SegformerImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 

Preprocessing and Inference Helper

SegFormer uses the HuggingFace processor for normalization and resizing. We pass numpy images (HWC uint8); the model returns logits at 1/4 of the processor's resized input, which we upsample back to the size of the image we passed in, so that the inverse transform aligns pixel-for-pixel with the original.

@torch.no_grad()
def predict_logits(image_np: np.ndarray) -> torch.Tensor:
    """Run SegFormer on an HWC uint8 image → [1, C, H, W] logits at the image's resolution."""
    # The processor resizes and normalizes (PIL Image is imported at the top of the notebook)
    inputs = processor(images=Image.fromarray(image_np), return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    outputs = model(**inputs)
    logits = outputs.logits  # [1, 2, h, w] at 1/4 of the processor's input size (bg, tree)
    # Upsample to this image's (H, W) so the inverse transform aligns pixel-for-pixel
    logits = torch.nn.functional.interpolate(
        logits, size=(image_np.shape[0], image_np.shape[1]),
        mode="bilinear", align_corners=False
    )
    return logits

Load Dataset from HuggingFace

restor/tcd: aerial imagery with binary tree-cover masks. We load the first BENCHMARK_IMAGES images of the test split.

ds = load_dataset("restor/tcd", split=f"test[:{BENCHMARK_IMAGES}]")
print(f"Dataset size: {len(ds)} images")
print(f"Columns: {ds.column_names}")
print(f"Image size: {np.array(ds[0]['image']).shape}")

Visualize a Sample

Let's look at one aerial image and its ground-truth mask to understand the data.

def mask_to_rgb(mask: np.ndarray, colors: list, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """Convert an integer class mask to an RGB image. Colors repeat (modulo) if there are more classes than colors."""
    h, w = mask.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    mask = np.clip(mask.astype(np.int32), 0, num_classes - 1)
    for cls_idx in range(num_classes):
        if (mask == cls_idx).any():
            rgb[mask == cls_idx] = colors[cls_idx % len(colors)]
    return rgb


def make_legend(class_names, class_colors):
    return [
        mpatches.Patch(color=np.array(c) / 255.0, label=n)
        for n, c in zip(class_names, class_colors)
    ]


sample = ds[0]
sample_image = np.array(sample["image"].convert("RGB"))
ann = np.array(sample["annotation"])
# restor/tcd: annotation has instance IDs; binarize to (background=0, tree=1)
sample_mask = ((ann > 0).any(axis=-1) if ann.ndim == 3 else (ann > 0)).astype(np.uint8)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(sample_image)
axes[0].set_title("Aerial image", fontsize=13)
axes[0].axis("off")

axes[1].imshow(mask_to_rgb(sample_mask, CLASS_COLORS))
axes[1].set_title("Ground truth mask", fontsize=13)
axes[1].axis("off")

fig.legend(
    handles=make_legend(CLASS_NAMES, CLASS_COLORS),
    loc="lower center",
    ncol=4,
    fontsize=9,
    bbox_to_anchor=(0.5, -0.18),
)
plt.tight_layout()
plt.show()

Defining the Four TTA Modes

Each mode is a list of D4 group elements. For each element we build the forward transform and its inverse; after every augmented forward pass, the inverse is applied to the logit map to bring it back into the original coordinate space before averaging.

AlbumentationsX 2.0.19 introduces:

  • group_element argument on D4, HorizontalFlip, RandomRotate90 etc. — for deterministic selection of a specific symmetry element
  • .inverse() method — returns a transform that undoes the spatial effect

This means the full D4 TTA loop is just iterating over d4_group_elements.

from albumentations.core.type_definitions import d4_group_elements

# rotate90_elements: the 4 pure rotations, which form the cyclic subgroup C4 of D4
rotate90_elements = ["e", "r90", "r180", "r270"]
hflip_elements    = ["e", "h"]

TTA_MODES = {
    "Baseline (1 view)": ["e"],
    "HorizontalFlip TTA (2 views)": hflip_elements,
    "RandomRotate90 TTA (4 views)": rotate90_elements,
    "D4 TTA (8 views)": list(d4_group_elements),
}

print("TTA modes:")
for name, elements in TTA_MODES.items():
    print(f"  {name}: {elements}")
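The rotate90_elements comment in the cell above describes the four rotations as a subgroup of D4, and {e, h} is one as well. A non-empty subset of a finite group is a subgroup exactly when it is closed under composition, which is easy to test numerically. This sketch uses hypothetical NumPy stand-ins for the elements (not the library's transforms) and identifies each composition by its effect on an asymmetric probe array.

```python
from itertools import product

import numpy as np

# Hypothetical NumPy stand-ins for the 8 D4 elements on 2D arrays (CCW rotations)
D4 = {
    "e": lambda x: x,
    "r90": lambda x: np.rot90(x, 1),
    "r180": lambda x: np.rot90(x, 2),
    "r270": lambda x: np.rot90(x, 3),
    "h": lambda x: x[:, ::-1],
    "v": lambda x: x[::-1],
    "t": lambda x: x.T,
    "hvt": lambda x: x[::-1, ::-1].T,
}

# All entries are distinct, so each of the 8 elements maps it to a different array
probe = np.arange(16).reshape(4, 4)


def identify(arr):
    """Return the name of the D4 element that maps `probe` to `arr`."""
    return next(name for name, f in D4.items() if np.array_equal(f(probe), arr))


def is_closed(elements):
    """True if applying any two elements in sequence always lands back in `elements`."""
    return all(
        identify(D4[a](D4[b](probe))) in elements
        for a, b in product(elements, repeat=2)
    )


print(is_closed(["e", "h"]))                    # True
print(is_closed(["e", "r90", "r180", "r270"]))  # True
print(is_closed(["e", "r90"]))                  # False: r90 then r90 gives r180
```

Subgroups are a natural choice of TTA set for a short algebraic reason: if S is a group and the TTA prediction is the average of g⁻¹(f(g(x))) over all g in S, then transforming the input by any h in S transforms the averaged prediction by exactly h. The averaged prediction is therefore exactly equivariant to every element of S, even when f itself is only approximately equivariant.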

TTA Inference Function

The key step for segmentation TTA is applying the inverse spatial transform to the logit map before averaging. If we skip this step and average in the transformed coordinate space, the predictions are misaligned and the result can be much worse than the baseline.

image  ──→  T(image)  ──→  model  ──→  logits_T  ──→  T⁻¹(logits_T)  ──→  average

We use aug.inverse(), which returns the exact inverse transform for each D4 element.

def invert_logits(aug: A.D4, logits: torch.Tensor) -> torch.Tensor:
    """
    Apply the spatial inverse of `aug` to a logit map.

    logits: [1, C, H, W] tensor
    Returns: [1, C, H, W] tensor in the original coordinate space
    """
    inv = aug.inverse()
    # Treat the logit volume as a multi-channel image: [1, C, H, W] -> [H, W, C]
    logits_hwc = logits.squeeze(0).cpu().numpy().transpose(1, 2, 0)

    # Apply the inverse transform; Albumentations handles arbitrary channel counts
    inv_result = inv(image=logits_hwc)["image"]  # [H, W, C]

    # Back to [1, C, H, W]. ascontiguousarray guards against negative strides
    # (e.g. from flips), which torch.from_numpy does not accept
    inv_chw = np.ascontiguousarray(inv_result.transpose(2, 0, 1))  # [C, H, W]
    return torch.from_numpy(inv_chw).unsqueeze(0).to(logits.device)


@torch.no_grad()
def run_tta(image_np: np.ndarray, group_elements: list) -> np.ndarray:
    """
    Run TTA inference over the given D4 group elements.

    Returns: predicted class map as np.ndarray [H, W]
    """
    accumulated = None

    for element in group_elements:
        aug = A.D4(p=1.0, group_element=element)
        aug_image = aug(image=image_np)["image"]

        logits = predict_logits(aug_image)          # [1, C, H, W]
        inv_logits = invert_logits(aug, logits)     # [1, C, H, W] in original space

        if accumulated is None:
            accumulated = inv_logits
        else:
            accumulated = accumulated + inv_logits

    avg_logits = accumulated / len(group_elements)  # [1, C, H, W]
    pred = avg_logits.squeeze(0).argmax(0).cpu().numpy()  # [H, W]
    return pred
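The warning about skipping the inverse step is easy to reproduce on a toy example. In this sketch, `toy_scorer` is a hypothetical pixelwise scorer (exactly equivariant, unlike a real network), and the input is an 8×8 patch with a single bright "tree crown" in its top-left corner. Averaging the four rotated predictions after undoing each rotation recovers the single-pass mask; averaging them in their rotated coordinate spaces spreads the crown over all four corners, where it no longer wins the argmax.

```python
import numpy as np


def toy_scorer(x):
    """Hypothetical pixelwise 2-class scorer: 'tree' logit = brightness - 0.5."""
    b = x.mean(axis=-1)
    return np.stack([0.5 - b, b - 0.5], axis=-1)  # [H, W, 2] logits (bg, tree)


patch = np.zeros((8, 8, 3))
patch[:3, :3] = 1.0  # one bright 3x3 'tree crown' in the top-left corner

ks = range(4)  # e, r90, r180, r270 as counter-clockwise np.rot90 steps
# Correct: rotate each logit map back by -k before averaging
aligned = np.mean([np.rot90(toy_scorer(np.rot90(patch, k)), -k) for k in ks], axis=0)
# Wrong: average the logit maps in their rotated coordinate spaces
unaligned = np.mean([toy_scorer(np.rot90(patch, k)) for k in ks], axis=0)

reference = toy_scorer(patch).argmax(axis=-1)
print(int(reference.sum()))                  # 9 tree pixels in the single-pass mask
print(int(aligned.argmax(axis=-1).sum()))    # 9: identical to the reference
print(int(unaligned.argmax(axis=-1).sum()))  # 0: the crown is averaged away
```

A real network will not fail in exactly this way, but the cause of the misalignment is the same: without the inverse, each logit map is expressed in a different coordinate frame.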

Visual Comparison on a Single Image

We run all four modes on the same aerial patch and display the predicted masks side by side alongside the ground truth.

image_np = sample_image.copy()
gt_mask  = sample_mask.copy()

mode_preds = {}
for mode_name, elements in TTA_MODES.items():
    mode_preds[mode_name] = run_tta(image_np, elements)

n_cols = 2 + len(TTA_MODES)  # image + GT + 4 modes
fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4))

axes[0].imshow(image_np)
axes[0].set_title("Aerial Image", fontsize=11)
axes[0].axis("off")

axes[1].imshow(mask_to_rgb(gt_mask, CLASS_COLORS))
axes[1].set_title("Ground Truth", fontsize=11)
axes[1].axis("off")

for ax, (mode_name, pred) in zip(axes[2:], mode_preds.items()):
    ax.imshow(mask_to_rgb(pred, CLASS_COLORS))
    ax.set_title(mode_name.split(" (")[0], fontsize=10)
    ax.axis("off")

fig.legend(
    handles=make_legend(CLASS_NAMES, CLASS_COLORS),
    loc="lower center",
    ncol=5,
    fontsize=8,
    bbox_to_anchor=(0.5, -0.18),
)
plt.suptitle("TTA Mode Comparison — Predicted Masks", fontsize=13, y=1.02)
plt.tight_layout()
plt.show()

mIoU Benchmark

We evaluate all four modes on a subset of images: for each image we compute the mean Intersection over Union (mIoU) against its ground-truth mask, then average these per-image scores. This gives a quantitative picture of how much each additional set of views improves accuracy.

Expected trend: each mode should typically match or exceed the previous one, with D4 TTA achieving the highest mIoU at the cost of 8× inference time.

def compute_miou(pred: np.ndarray, gt: np.ndarray, num_classes: int) -> float:
    """Compute mean IoU over the classes present in gt or pred (classes absent from both are skipped)."""
    ious = []
    for cls in range(num_classes):
        pred_cls = pred == cls
        gt_cls   = gt   == cls
        if not gt_cls.any() and not pred_cls.any():
            continue  # skip classes absent from both
        intersection = (pred_cls & gt_cls).sum()
        union        = (pred_cls | gt_cls).sum()
        if union == 0:
            continue
        ious.append(intersection / union)
    return float(np.mean(ious)) if ious else 0.0


# Accumulate mIoU across BENCHMARK_IMAGES images
miou_totals = {name: 0.0 for name in TTA_MODES}
n_eval = min(BENCHMARK_IMAGES, len(ds))

for idx in tqdm(range(n_eval), desc="Benchmarking"):
    sample_i  = ds[idx]
    img_i     = np.array(sample_i["image"].convert("RGB"))
    ann_i = np.array(sample_i["annotation"])
    mask_i = ((ann_i > 0).any(axis=-1) if ann_i.ndim == 3 else (ann_i > 0)).astype(np.uint8)

    for mode_name, elements in TTA_MODES.items():
        pred_i = run_tta(img_i, elements)
        pred_i = np.clip(pred_i, 0, NUM_CLASSES - 1)
        miou_totals[mode_name] += compute_miou(pred_i, mask_i, NUM_CLASSES)

miou_results = {name: total / n_eval for name, total in miou_totals.items()}

print(f"\nmIoU over {n_eval} images\n")
print(f"{'Mode':<35} {'Views':>6} {'mIoU':>8}")
print("-" * 52)
views_map = {name: len(els) for name, els in TTA_MODES.items()}
for mode_name, miou in miou_results.items():
    print(f"{mode_name:<35} {views_map[mode_name]:>6} {miou:>8.4f}")
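A hand-checkable example helps confirm what compute_miou measures. The function is repeated here with the same logic so the cell runs on its own; the 2×2 masks are made up for illustration.

```python
import numpy as np


def compute_miou(pred, gt, num_classes):
    """Same logic as the notebook's compute_miou, repeated so this cell runs on its own."""
    ious = []
    for cls in range(num_classes):
        pred_cls, gt_cls = pred == cls, gt == cls
        union = (pred_cls | gt_cls).sum()
        if union == 0:
            continue  # a class absent from both masks is skipped, not scored
        ious.append((pred_cls & gt_cls).sum() / union)
    return float(np.mean(ious)) if ious else 0.0


gt = np.array([[0, 1],
               [0, 1]])
pred = np.array([[0, 1],
                 [1, 1]])
# background: intersection 1 / union 2 = 0.5; tree: intersection 2 / union 3 ≈ 0.667
print(round(compute_miou(pred, gt, 2), 4))  # 0.5833

# An all-background image with an all-background prediction scores 1.0,
# because the absent 'tree' class is skipped rather than counted as IoU 0
print(compute_miou(np.zeros((2, 2)), np.zeros((2, 2)), 2))  # 1.0
```

Because the benchmark averages per-image scores, an image with only a few tree pixels counts as much as a densely forested one. An alternative is dataset-level IoU, which sums intersections and unions over all images before dividing; on a small subset the two can rank the modes differently.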

Results Summary

Let's visualise the accuracy vs. inference cost trade-off.

labels = list(miou_results.keys())
short_labels = ["Baseline\n(1 view)", "HFlip\n(2 views)", "Rotate90\n(4 views)", "D4\n(8 views)"]
mious  = [miou_results[l] for l in labels]
views  = [views_map[l] for l in labels]

colors = ["#4393c3", "#74add1", "#fdae61", "#d73027"]

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(short_labels, mious, color=colors, edgecolor="white", linewidth=0.8)

for bar, miou in zip(bars, mious):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.002,
        f"{miou:.4f}",
        ha="center", va="bottom", fontsize=11, fontweight="bold",
    )

ax.set_ylabel("mIoU", fontsize=12)
ax.set_title(
    f"TTA Accuracy vs. Inference Cost\n(evaluated on {n_eval} aerial images)",
    fontsize=13,
)
ax.set_ylim(0, max(mious) * 1.12)
ax.grid(axis="y", alpha=0.3)
ax.spines[["top", "right"]].set_visible(False)

plt.tight_layout()
plt.show()

# Print gain over baseline in percentage points (mIoU × 100)
baseline_miou = mious[0]
print("\nGain over baseline:")
for label, miou, nv in zip(short_labels, mious, views):
    gain = (miou - baseline_miou) * 100
    label_clean = label.replace("\n", " ")
    # The "+" format flag prints the sign, so a negative gain is not shown as "+-"
    print(f"  {label_clean:<22} {gain:+.2f} pp mIoU  ({nv}x inference cost)")

Key Takeaways

  • TTA is not free accuracy — it is a deliberate accuracy/compute trade-off. Each additional view requires a full forward pass through the model. 8-view D4 TTA costs 8× the inference time of a single forward pass. Whether that trade-off is worth it depends on your latency budget and how much the accuracy gain matters for your application.
  • The improvement is almost always positive, but not guaranteed. When the task has genuine geometric symmetries and the model was trained with those augmentations, TTA reliably helps. When those conditions don't hold, gains can be negligible.
  • Equivariance is the key insight for segmentation. We must apply the inverse spatial transform to the logit map before averaging — otherwise predictions are misaligned and the result can be worse than baseline.
  • D4 is the natural group for top-down imagery. When there is no privileged orientation, as in aerial, satellite, and microscopy images, all 8 dihedral symmetries are valid, so D4 provides the largest set of valid flip and rotation views to average over.
  • AlbumentationsX 2.0.19 makes the implementation trivial: group_element selects a deterministic symmetry element, and .inverse() provides the exact inverse transform — no manual bookkeeping needed.

For street-level images, only HFlip is appropriate. For aerial imagery, D4 TTA is the principled choice — as long as you have the compute budget for 8× inference.

Further reading: AlbumentationsX TTA Documentation