Semantic Segmentation with Pretrained U-Net in Keras + Albumentations
🎯 Overview
This notebook demonstrates transfer-learning semantic segmentation using:
- Encoders pretrained on ImageNet (ResNet50, ResNet50V2 as a ResNet34 stand-in, MobileNetV2, VGG16)
- U-Net decoder architecture
- Albumentations for joint image and mask augmentation
- Automatic device detection (CUDA GPU / Apple MPS / CPU)
- Oxford-IIIT Pet dataset (segmenting pets from background)
⚠️ Important for Apple Silicon (M1/M2/M3) Users:
To enable GPU acceleration via Metal Performance Shaders (MPS), you need:
pip install tensorflow-metal
Without tensorflow-metal, TensorFlow will use CPU only!
Key Features:
✅ Transfer learning from ImageNet weights
✅ Modern Keras 3.x implementation
✅ Segmentation-optimized augmentations
✅ Progress tracking with tqdm
✅ Typical Dice score: roughly 0.75-0.85 after 10-15 epochs (varies with backbone, hardware, and random seed)
Why Pretrained Models?
- Faster convergence: usually needs far fewer epochs than training from scratch
- Better performance: typically higher Dice scores with less data
- Transfer learning: Leverages ImageNet features for segmentation
📦 Installation
# Install required packages
%pip install -q --upgrade albumentationsx tensorflow tensorflow-datasets keras matplotlib tqdm jupytext
Note: you may need to restart the kernel to use updated packages.
🔧 Imports and Setup
import os
import platform
import warnings
from typing import Tuple
from dataclasses import dataclass
import numpy as np
# Silence TensorFlow's C++ info/warning logs; this only takes effect if set before TensorFlow is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers, models, optimizers, callbacks
import albumentations as A
import matplotlib.pyplot as plt
from tqdm.keras import TqdmCallback
from tqdm import tqdm
# Display plots inline
%matplotlib inline
# Suppress Python warnings for cleaner output
warnings.filterwarnings('ignore')
# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)
print("Versions:")
print(f" TensorFlow: {tf.__version__}")
print(f" Keras: {keras.__version__}")
print(f" Albumentations: {A.__version__}")
print(f" Platform: {platform.system()} {platform.processor()}")
Versions:
 TensorFlow: 2.19.1
 Keras: 3.11.3
 Albumentations: 2.0.11
 Platform: Darwin arm
🖥️ Automatic Device Detection
Automatically detects and configures the best available compute device:
def setup_device():
    """Detect the available compute device and configure TensorFlow for it."""
    system = platform.system()
    processor = platform.processor()
    # Check for GPU availability
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            # Enable memory growth so TensorFlow allocates GPU memory on demand
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # On macOS, a visible GPU comes from the tensorflow-metal (MPS) plugin
            if system == 'Darwin':
                print("✓ Apple Silicon MPS GPU detected")
                print(" TensorFlow will use Metal Performance Shaders")
                return 'MPS'
            else:
                device_name = f"CUDA GPU ({len(gpus)} device(s))"
                print(f"✓ Using {device_name}")
                try:
                    gpu_details = tf.config.experimental.get_device_details(gpus[0])
                    print(f" GPU Name: {gpu_details.get('device_name', 'Unknown')}")
                except Exception:
                    # Device details are informational only
                    pass
                return 'GPU'
        except RuntimeError as e:
            # Raised if memory growth is set after the GPUs have been initialized
            print(f"GPU setup failed: {e}")
    # Check if on Apple Silicon without GPU detected
    if system == 'Darwin' and processor == 'arm':
        print("⚠️ Apple Silicon detected but MPS not available!")
        print(" To enable GPU acceleration, install tensorflow-metal:")
        print(" pip install tensorflow-metal")
        print(" Currently using CPU - training will be slower")
        return 'CPU'
    # Regular CPU fallback
    print("ℹ Using CPU (no GPU detected)")
    print(" Training will be slower. Consider using Google Colab with GPU.")
    return 'CPU'

# Detect and setup device
device_type = setup_device()
✓ Apple Silicon MPS GPU detected
 TensorFlow will use Metal Performance Shaders
⚙️ Configuration
Configure training parameters and model settings:
@dataclass
class Config:
    """Configuration for training."""
    # Model
    backbone: str = 'resnet50'  # Options: 'resnet34' (built as ResNet50V2), 'resnet50', 'mobilenetv2', 'vgg16'
    input_shape: Tuple[int, int, int] = (256, 256, 3)  # Sides divisible by 32; must match the crop size in create_augmentations
    num_classes: int = 3  # Background, pet, border
    # Training
    batch_size: int = 8  # Adjust based on your GPU memory
    epochs: int = 10  # Pretrained models converge faster
    learning_rate: float = 5e-4
    # Dataset splits (all carved from the TFDS 'train' split)
    train_split: str = "train[:80%]"
    val_split: str = "train[80%:90%]"
    test_split: str = "train[90%:]"
    # Paths
    checkpoint_dir: str = 'checkpoints'

# Create configuration
config = Config()
print("Training Configuration:")
print(f" Backbone: {config.backbone.upper()} (pretrained on ImageNet)")
print(f" Input shape: {config.input_shape}")
print(f" Batch size: {config.batch_size}")
print(f" Epochs: {config.epochs}")
print(f" Learning rate: {config.learning_rate}")
print(f" Classes: {config.num_classes} (background, pet, border)")
# Adjust batch size based on device
if device_type == 'CPU':
    config.batch_size = 4
    print(f" → Reduced batch size to {config.batch_size} for CPU")
Training Configuration:
 Backbone: RESNET50 (pretrained on ImageNet)
 Input shape: (256, 256, 3)
 Batch size: 8
 Epochs: 10
 Learning rate: 0.0005
 Classes: 3 (background, pet, border)
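Because Config is a dataclass, variants for backbone experiments can be derived with `dataclasses.replace` instead of mutating fields in place (as the CPU batch-size override does). A sketch using a trimmed, hypothetical copy of Config that keeps only a few of its fields; the variant values are illustrative:

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass
class Config:
    # Hypothetical trimmed copy of the notebook's Config (only a few fields)
    backbone: str = 'resnet50'
    input_shape: Tuple[int, int, int] = (256, 256, 3)
    batch_size: int = 8
    learning_rate: float = 5e-4

base = Config()
# replace() builds a new instance; `base` itself is left untouched
mobile = replace(base, backbone='mobilenetv2', batch_size=16)
cpu = replace(base, batch_size=4)
print(base)
print(mobile)
print(cpu)
```

A misspelled field name passed to `replace` raises `TypeError`, whereas assigning to a misspelled attribute on a plain dataclass silently creates a new attribute.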
🏗️ U-Net Model Architecture
Build U-Net with pretrained ImageNet encoder:
def build_unet_model(backbone_name: str, input_shape: Tuple, num_classes: int) -> keras.Model:
    """
    Build U-Net with pretrained backbone from ImageNet.
    The U-Net architecture consists of:
    1. Encoder: Pretrained backbone (ResNet, MobileNet, or VGG)
    2. Decoder: Transposed convolutions with skip connections
    3. Output: Softmax activation for multi-class segmentation
    """
    inputs = layers.Input(shape=input_shape)
    # Select pretrained encoder based on backbone choice.
    # For a 256x256 input, each backbone's four skip layers sit at 128, 64, 32 and 16 px,
    # and the encoder output is 8x8.
    if backbone_name == 'resnet34':
        # Keras doesn't ship ResNet34, so ResNet50V2 is used as a stand-in
        print("Loading ResNet50V2 (ResNet34 stand-in) pretrained on ImageNet...")
        encoder = tf.keras.applications.ResNet50V2(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['conv1_conv', 'conv2_block3_1_relu',
                            'conv3_block4_1_relu', 'conv4_block6_1_relu']
    elif backbone_name == 'resnet50':
        print("Loading ResNet50 pretrained on ImageNet...")
        encoder = tf.keras.applications.ResNet50(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['conv1_relu', 'conv2_block3_out',
                            'conv3_block4_out', 'conv4_block6_out']
    elif backbone_name == 'mobilenetv2':
        print("Loading MobileNetV2 pretrained on ImageNet...")
        encoder = tf.keras.applications.MobileNetV2(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['block_1_expand_relu', 'block_3_expand_relu',
                            'block_6_expand_relu', 'block_13_expand_relu']
    elif backbone_name == 'vgg16':
        print("Loading VGG16 pretrained on ImageNet...")
        encoder = tf.keras.applications.VGG16(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['block1_pool', 'block2_pool',
                            'block3_pool', 'block4_pool']
    else:
        raise ValueError(f"Unknown backbone: {backbone_name}")
    # Make encoder trainable for fine-tuning
    encoder.trainable = True
    # Get skip connections from encoder layers
    skip_layers = [encoder.get_layer(name).output for name in skip_layer_names]
    # Build U-Net decoder
    x = encoder.output
    decoder_filters = [256, 128, 64, 32]
    # Decoder blocks with skip connections, deepest skip first
    for skip, filters in zip(reversed(skip_layers), decoder_filters):
        # Upsampling: a stride-2 transposed conv doubles height and width
        x = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        # Skip connection (concatenate along the channel axis)
        x = layers.Concatenate()([x, skip])
        # Double convolution block
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        # Dropout for regularization (except last block)
        if filters > 32:
            x = layers.Dropout(0.3)(x)
    # Final upsampling to match input resolution
    x = layers.Conv2DTranspose(16, 3, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # Output layer with softmax for multi-class segmentation
    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)
    model = models.Model(inputs=inputs, outputs=outputs, name=f'{backbone_name}_unet')
    print(f"Model created: {model.count_params():,} parameters")
    return model
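The decoder doubles the spatial size at each of its five stride-2 transposed convolutions, so every upsampled tensor must land exactly on its skip connection's resolution before `Concatenate`. The pure-Python sketch below traces those sizes, assuming each encoder stage halves the resolution and rounds up (as ResNet50's padded stride-2 layers do; VGG16's pooling rounds down instead). `encoder_sizes` and `decoder_mismatches` are hypothetical helpers, not part of Keras:

```python
import math
from typing import List, Tuple

def encoder_sizes(side: int, num_downsamples: int = 5) -> List[int]:
    """Spatial size after each encoder stage, assuming ceil-halving per stage."""
    sizes = []
    for _ in range(num_downsamples):
        side = math.ceil(side / 2)
        sizes.append(side)
    return sizes  # four skip resolutions followed by the bottleneck

def decoder_mismatches(side: int) -> List[Tuple[int, int]]:
    """Return (upsampled, expected) pairs where the U-Net decoder would not line up."""
    sizes = encoder_sizes(side)
    skips, x = sizes[:4], sizes[4]
    mismatches = []
    for skip in reversed(skips):
        x *= 2  # Conv2DTranspose(strides=2, padding='same') exactly doubles H and W
        if x != skip:
            mismatches.append((x, skip))
        x = skip  # keep tracing as if the shapes had matched
    if x * 2 != side:  # final upsampling back to the input resolution
        mismatches.append((x * 2, side))
    return mismatches

for side in (256, 250, 384):
    print(side, encoder_sizes(side), decoder_mismatches(side))
```

For 256 and 384 the trace is clean; for 250 the 64-vs-63 pair means `Concatenate` would fail, which is why input sides should be multiples of 32 (2 to the power of the five downsampling stages).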
🎨 Albumentations Augmentation Pipeline
Create segmentation-optimized augmentations with proper normalization:
def create_augmentations(training: bool = True, backbone_name: str = 'resnet50'):
    """Create Albumentations pipeline optimized for segmentation.
    Key principles for segmentation augmentations:
    - Avoid aggressive geometric distortions (no ElasticTransform, GridDistortion)
    - Use conservative spatial transforms (mild rotation, scaling)
    - Apply color augmentations for robustness
    - Include normalization as the last step
    """
    # Normalization parameters for each backbone
    if backbone_name == 'mobilenetv2':
        # Scale pixels to [-1, 1], matching Keras's MobileNetV2 preprocessing
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    else:
        # ImageNet mean/std for 'resnet34', 'resnet50', 'vgg16' and any other name.
        # Keras's own preprocess_input differs ('caffe' BGR mean subtraction for
        # ResNet50/VGG16, [-1, 1] scaling for ResNet50V2); because the whole encoder
        # is fine-tuned, it can partly adapt to this mismatch.
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    if training:
        return A.Compose([
            # Spatial augmentations, applied identically to image and mask.
            # The 256x256 crop must match config.input_shape.
            A.RandomCrop(height=256, width=256, pad_if_needed=True),
            A.HorizontalFlip(p=0.5),
            A.Affine(
                scale=(0.5, 2),  # 0.5x-2x zoom (a fairly wide range)
                rotate=(-15, 15),  # Mild rotation
                balanced_scale=True,  # Roughly equal odds of shrinking vs. enlarging
                keep_ratio=True,
                p=0.5
            ),
            # Color augmentations (image only; masks are unaffected)
            A.ColorJitter(
                brightness=[0.8, 1.2],
                contrast=[0.8, 1.2],
                saturation=[0.8, 1.2],
                hue=[-0.5, 0.5],  # Full hue range, the widest setting allowed
                p=1
            ),
            # Noise or blur on 20% of samples
            A.OneOf([
                A.GaussNoise(std_range=(0.1, 0.2), p=1),
                A.GaussianBlur(blur_limit=(3, 5), p=1),
            ], p=0.2),
            # Normalization (must be last!)
            A.Normalize(mean=mean, std=std),
        ], seed=137, strict=True)  # seed for reproducibility
    else:
        return A.Compose([
            # For validation/test: center-crop (padding if needed) and normalize; no resizing
            A.CenterCrop(height=256, width=256, pad_if_needed=True, p=1),
            A.Normalize(mean=mean, std=std),
        ])
# Preview augmentation settings
print("Augmentation Pipeline:")
print(" Training:")
print(" - RandomCrop with padding")
print(" - HorizontalFlip (50%)")
print(" - Affine transforms (scale, rotate)")
print(" - ColorJitter")
print(" - Noise/Blur (20%)")
print(" - Normalization")
print(" Validation:")
print(" - CenterCrop with padding")
print(" - Normalization")
Augmentation Pipeline:
 Training:
 - RandomCrop with padding
 - HorizontalFlip (50%)
 - Affine transforms (scale, rotate)
 - ColorJitter
 - Noise/Blur (20%)
 - Normalization
 Validation:
 - CenterCrop with padding
 - Normalization
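The pipeline normalizes with ImageNet mean/std through `A.Normalize`, while Keras's pretrained weights were trained with their own preprocessing modes. This NumPy-only sketch compares the three conventions on a single pixel; the caffe and [-1, 1] formulas follow the modes of Keras's `preprocess_input` as we understand them, and the function names are hypothetical:

```python
import numpy as np

# ImageNet statistics as passed to A.Normalize in create_augmentations (RGB order, 0-1 scale)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
# Channel means used by Keras's 'caffe' mode (BGR order, 0-255 scale)
CAFFE_MEAN_BGR = np.array([103.939, 116.779, 123.68])

def normalize_imagenet(rgb):
    # Same as A.Normalize(mean, std) with its default max_pixel_value=255
    return (rgb / 255.0 - IMAGENET_MEAN) / IMAGENET_STD

def normalize_minus1_1(rgb):
    # Keras 'tf' mode (MobileNetV2, ResNet50V2); equals A.Normalize(mean=0.5, std=0.5)
    return rgb / 127.5 - 1.0

def normalize_caffe(rgb):
    # Keras 'caffe' mode (ResNet50, VGG16): flip RGB -> BGR, subtract means, no scaling
    return rgb[..., ::-1] - CAFFE_MEAN_BGR

pixel = np.array([[[255.0, 128.0, 0.0]]])  # one orange pixel, shape (1, 1, 3)
for name, fn in [("imagenet", normalize_imagenet),
                 ("minus1_1", normalize_minus1_1),
                 ("caffe", normalize_caffe)]:
    print(f"{name:>9}: {np.round(fn(pixel).ravel(), 3)}")
```

The caffe values are on a 0-255 scale and in reversed channel order, so ImageNet-normalized inputs look quite different to a ResNet50 or VGG16 encoder at the start of fine-tuning.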
📊 Loss Functions and Metrics
Define segmentation-specific loss functions and metrics:
def dice_coefficient(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient for one-hot masks.
    Dice = 2 * |A ∩ B| / (|A| + |B|)
    Computed per class over every pixel in the batch, then averaged over classes.
    Range: [0, 1] where 1 is perfect overlap
    """
    # Flatten (batch, H, W, C) to (pixels, C)
    y_true_f = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
    y_pred_f = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    # |A| + |B| per class (a sum, not a set union)
    total = tf.reduce_sum(y_true_f, axis=0) + tf.reduce_sum(y_pred_f, axis=0)
    dice = tf.reduce_mean((2. * intersection + smooth) / (total + smooth))
    return dice

def dice_loss(y_true, y_pred):
    """Dice loss for training."""
    return 1.0 - dice_coefficient(y_true, y_pred)

def combined_loss(y_true, y_pred):
    """Combined dice + categorical crossentropy loss.
    This combination helps with:
    - Dice: Handles class imbalance
    - CrossEntropy: Provides stable gradients
    """
    cce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return dice_loss(y_true, y_pred) + tf.reduce_mean(cce)

class MeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU metric for one-hot encoded masks.
    (Keras also provides keras.metrics.OneHotMeanIoU for this case.)
    """
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert from one-hot to class indices
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)

print("Loss and Metrics:")
print(" Loss: Combined (Dice + Categorical Crossentropy)")
print(" Metrics: Dice Coefficient, Mean IoU, Categorical Accuracy")
Loss and Metrics:
 Loss: Combined (Dice + Categorical Crossentropy)
 Metrics: Dice Coefficient, Mean IoU, Categorical Accuracy
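To sanity-check the metric definitions without TensorFlow, here is a NumPy sketch that mirrors `dice_coefficient` and the argmax-based Mean IoU on a 2x2 toy mask. The `_np` helper names are hypothetical, and the IoU helper assumes Keras's convention of averaging only over classes that appear in either mask:

```python
import numpy as np

def dice_coefficient_np(y_true, y_pred, smooth=1.0):
    # Mirrors dice_coefficient: flatten to (pixels, classes), soft Dice per class, mean over classes
    c = y_true.shape[-1]
    t = y_true.reshape(-1, c)
    p = y_pred.reshape(-1, c)
    intersection = (t * p).sum(axis=0)
    total = t.sum(axis=0) + p.sum(axis=0)
    return float(np.mean((2.0 * intersection + smooth) / (total + smooth)))

def mean_iou_np(y_true, y_pred):
    # Hard IoU after argmax, averaged over classes present in either mask
    t = y_true.argmax(axis=-1).ravel()
    p = y_pred.argmax(axis=-1).ravel()
    ious = []
    for c in range(y_true.shape[-1]):
        union = np.sum((t == c) | (p == c))
        if union > 0:
            ious.append(np.sum((t == c) & (p == c)) / union)
    return float(np.mean(ious))

# 2x2 toy masks (0=background, 1=pet, 2=border): one pet pixel is predicted as border
true_labels = np.array([[0, 1], [1, 2]])
pred_labels = np.array([[0, 1], [2, 2]])
y_true = np.eye(3)[true_labels]  # one-hot, shape (2, 2, 3)
y_pred = np.eye(3)[pred_labels]  # a "hard" prediction in one-hot form

print(dice_coefficient_np(y_true, y_pred))  # (1 + 0.75 + 0.75) / 3 with smooth=1
print(mean_iou_np(y_true, y_pred))          # (1 + 0.5 + 0.5) / 3
```

With `smooth=0` and hard predictions, per-class Dice equals 2 * IoU / (1 + IoU) (here 2/3 for an IoU of 0.5), which is why Dice values read higher than Mean IoU on the same predictions.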
📁 Dataset Loading
Load Oxford-IIIT Pet dataset for semantic segmentation:
def load_dataset(config: Config):
    """Load Oxford-IIIT Pet dataset.
    The dataset contains:
    - Images of pets (cats and dogs)
    - Segmentation masks with 3 classes (after remapping in preprocess_data):
      0: Background
      1: Pet
      2: Border
    """
    print("Loading Oxford-IIIT Pet dataset...")
    print("This may take a few minutes on first run...")
    # tfds.load returns all three splits in one call, so this bar only marks completion
    with tqdm(total=3, desc="Loading splits", unit="split") as pbar:
        (ds_train, ds_val, ds_test), ds_info = tfds.load(
            'oxford_iiit_pet',
            split=[config.train_split, config.val_split, config.test_split],
            with_info=True
        )
        pbar.update(3)
    print(f"\n✓ Dataset loaded successfully")
    print(f" Train samples: {len(ds_train)}")
    print(f" Val samples: {len(ds_val)}")
    print(f" Test samples: {len(ds_test)}")
    return ds_train, ds_val, ds_test, ds_info

# Load the dataset
ds_train, ds_val, ds_test, ds_info = load_dataset(config)
Loading Oxford-IIIT Pet dataset...
This may take a few minutes on first run...
Loading splits: 100%|██████████| 3/3 [00:00<00:00, 41.73split/s]
✓ Dataset loaded successfully
 Train samples: 2944
 Val samples: 368
 Test samples: 368
🔄 Data Pipeline
Create efficient tf.data pipeline with augmentations:
def preprocess_data(data_dict):
    """Preprocess raw data from TensorFlow datasets."""
    image = tf.cast(data_dict['image'], tf.float32)
    mask = data_dict['segmentation_mask']
    # Convert mask to int32 and drop the trailing channel axis: (H, W, 1) -> (H, W)
    mask = tf.cast(mask, tf.int32)
    mask = tf.squeeze(mask, axis=-1)
    # Remap trimap values to contiguous class indices
    # Original trimap: 1=pet, 2=background, 3=border
    # New: 0=background, 1=pet, 2=border
    # (Order matters: remapping 3 -> 2 first would merge border into background)
    mask = tf.where(mask == 2, 0, mask)  # background -> 0
    mask = tf.where(mask == 3, 2, mask)  # border -> 2
    return image, mask

def create_data_pipeline(dataset, config: Config, augmentations, training: bool = True):
    """Create tf.data pipeline with augmentations."""
    def augment_fn(image, mask):
        """Apply Albumentations augmentations."""
        def aug(img, msk):
            img = img.numpy().astype(np.uint8)
            msk = msk.numpy().astype(np.uint8)
            # Apply augmentations (including normalization)
            augmented = augmentations(image=img, mask=msk)
            return augmented['image'], augmented['mask'].astype(np.int32)
        # Use tf.py_function to apply numpy-based augmentations
        aug_img, aug_mask = tf.py_function(
            aug, [image, mask], [tf.float32, tf.int32]
        )
        # Set shapes (required after py_function)
        aug_img.set_shape([config.input_shape[0], config.input_shape[1], 3])
        aug_mask.set_shape([config.input_shape[0], config.input_shape[1]])
        # One-hot encode the mask for multi-class segmentation
        aug_mask = tf.one_hot(aug_mask, config.num_classes)
        return aug_img, aug_mask
    # Build pipeline. Shuffle raw examples first so the buffer holds compact
    # uint8 data rather than augmented float32 tensors.
    if training:
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(config.batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# Create augmentation pipelines
train_aug = create_augmentations(training=True, backbone_name=config.backbone)
val_aug = create_augmentations(training=False, backbone_name=config.backbone)
# Create data pipelines
print("\nCreating data pipelines...")
train_dataset = create_data_pipeline(ds_train, config, train_aug, training=True)
val_dataset = create_data_pipeline(ds_val, config, val_aug, training=False)
test_dataset = create_data_pipeline(ds_test, config, val_aug, training=False)
print("✓ Data pipelines ready")
Creating data pipelines...
✓ Data pipelines ready
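The mask remapping in `preprocess_data` is easy to check on a toy trimap. A NumPy sketch, assuming the raw Oxford-IIIT Pet trimap values 1 = pet, 2 = background, 3 = border; `remap_trimap` and `remap_wrong_order` are hypothetical helpers that reproduce the two `tf.where` steps with `np.where`:

```python
import numpy as np

# Toy trimap using the raw Oxford-IIIT Pet values: 1 = pet, 2 = background, 3 = border
raw = np.array([[2, 2, 3],
                [2, 3, 1],
                [3, 1, 1]], dtype=np.int32)

def remap_trimap(mask):
    # Same two steps as preprocess_data
    mask = np.where(mask == 2, 0, mask)  # background 2 -> 0
    mask = np.where(mask == 3, 2, mask)  # border 3 -> 2 (pet stays 1)
    return mask

def remap_wrong_order(mask):
    # Swapped steps: border becomes 2 and is then turned into 0, merging it with background
    mask = np.where(mask == 3, 2, mask)
    return np.where(mask == 2, 0, mask)

remapped = remap_trimap(raw)
one_hot = np.eye(3, dtype=np.float32)[remapped]  # same layout as tf.one_hot(mask, 3)
print(remapped)
print(one_hot.shape)
print(remap_wrong_order(raw))
```

The wrong-order variant contains no 2s at all, so the border class would silently vanish from training.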
👀 Visualize Data Samples
Preview augmented training samples:
def denormalize_image(img, backbone_name='resnet50'):
    """Undo A.Normalize so the image can be displayed with imshow."""
    # Must mirror the mean/std chosen in create_augmentations
    if backbone_name == 'mobilenetv2':
        # MobileNetV2 inputs were scaled to [-1, 1]
        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.5, 0.5, 0.5])
    else:
        # ImageNet mean/std for all other backbones
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
    # Denormalize: x = (x_norm * std) + mean
    img_denorm = (img * std) + mean
    # Clip to [0, 1] range for valid image
    img_denorm = np.clip(img_denorm, 0, 1)
    return img_denorm

def visualize_samples(dataset, num_samples=3):
    """Visualize images and masks from dataset."""
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 3))
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    # Get a batch of data
    for images, masks in dataset.take(1):
        for i in range(min(num_samples, images.shape[0])):
            # Properly denormalize image for visualization
            img = images[i].numpy()
            img = denormalize_image(img, config.backbone)
            # Convert one-hot mask to class indices
            mask = tf.argmax(masks[i], axis=-1).numpy()
            # Create colored mask for better visualization
            colored_mask = np.zeros((*mask.shape, 3))
            colored_mask[mask == 0] = [0, 0, 0]  # Background: black
            colored_mask[mask == 1] = [0, 1, 0]  # Pet: green
            colored_mask[mask == 2] = [1, 0, 0]  # Border: red
            # Plot image
            axes[i, 0].imshow(img)
            axes[i, 0].set_title("Input Image")
            axes[i, 0].axis('off')
            # Plot mask
            axes[i, 1].imshow(colored_mask)
            axes[i, 1].set_title("Segmentation Mask")
            axes[i, 1].axis('off')
            # Plot overlay
            axes[i, 2].imshow(img)
            axes[i, 2].imshow(colored_mask, alpha=0.5)
            axes[i, 2].set_title("Overlay")
            axes[i, 2].axis('off')
    plt.suptitle("Training Samples with Albumentations Augmentations", fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize training samples
print("Training samples with augmentations:")
visualize_samples(train_dataset, num_samples=3)
Training samples with augmentations:
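The colored-mask and denormalization steps used for plotting can be written more compactly. A NumPy sketch: `PALETTE`, `colorize` and `denormalize` are hypothetical names, and the round trip assumes `A.Normalize`'s default max_pixel_value=255, which is why denormalized images land in [0, 1]:

```python
import numpy as np

# Class index -> RGB color, same scheme as visualize_samples
PALETTE = np.array([[0, 0, 0],   # 0: background, black
                    [0, 1, 0],   # 1: pet, green
                    [1, 0, 0]],  # 2: border, red
                   dtype=np.float32)

def colorize(mask):
    # Fancy indexing maps an (H, W) integer mask to an (H, W, 3) color image in one step
    return PALETTE[mask]

def denormalize(img, mean, std):
    # Inverse of A.Normalize (max_pixel_value=255); output is in [0, 1] for imshow
    return np.clip(img * np.asarray(std) + np.asarray(mean), 0.0, 1.0)

mask = np.array([[0, 1], [2, 1]])
colored = colorize(mask)

mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
rgb = np.random.default_rng(0).integers(0, 256, size=(4, 4, 3))
normalized = (rgb / 255.0 - np.asarray(mean)) / np.asarray(std)  # what A.Normalize produces
restored = denormalize(normalized, mean, std)
print(colored.shape, np.abs(restored - rgb / 255.0).max())
```

Indexing a palette with the mask replaces one boolean assignment per class, which stays a single line as `num_classes` grows.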

🚀 Build and Compile Model
# Build the model
print(f"Building {config.backbone.upper()} U-Net...\n")
model = build_unet_model(config.backbone, config.input_shape, config.num_classes)
# Compile the model
model.compile(
    optimizer=optimizers.Adam(learning_rate=config.learning_rate),
    loss=combined_loss,
    metrics=[
        dice_coefficient,
        MeanIoU(num_classes=config.num_classes, name='mean_iou'),
        'categorical_accuracy'
    ]
)
print("\n📊 Model Summary:")
print(f" Architecture: U-Net with {config.backbone.upper()} encoder")
print(f" Total parameters: {model.count_params():,}")
print(f" Loss: Combined (Dice + Categorical Crossentropy)")
print(f" Optimizer: Adam (lr={config.learning_rate})")
Building RESNET50 U-Net...

Loading ResNet50 pretrained on ImageNet...
Model created: 33,387,043 parameters

📊 Model Summary:
 Architecture: U-Net with RESNET50 encoder
 Total parameters: 33,387,043
 Loss: Combined (Dice + Categorical Crossentropy)
 Optimizer: Adam (lr=0.0005)
🏋️ Training
Train the model with callbacks for monitoring and early stopping:
# Setup callbacks
os.makedirs(config.checkpoint_dir, exist_ok=True)
checkpoint = callbacks.ModelCheckpoint(
    filepath=os.path.join(config.checkpoint_dir, f'best_{config.backbone}_unet.keras'),
    monitor='val_dice_coefficient',
    mode='max',
    save_best_only=True,
    verbose=0
)
early_stop = callbacks.EarlyStopping(
    monitor='val_dice_coefficient',
    mode='max',
    patience=7,  # Needs 7 epochs without improvement, so it rarely fires in a 10-epoch run
    restore_best_weights=True,
    verbose=0
)
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_dice_coefficient',
    mode='max',
    factor=0.5,  # Halve the learning rate...
    patience=3,  # ...after 3 epochs without val Dice improvement
    min_lr=1e-7,
    verbose=0
)
# Create tqdm callback for progress bars
tqdm_callback = TqdmCallback(verbose=2, leave=True)
# Train the model
print(f"\n🚂 Starting training for {config.epochs} epochs...")
print(f" Backbone: {config.backbone.upper()} (pretrained)")
print(f" Batch size: {config.batch_size}")
print(f" Learning rate: {config.learning_rate}\n")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=config.epochs,
    callbacks=[checkpoint, early_stop, reduce_lr, tqdm_callback],
    verbose=0  # Disable default progress bar (using tqdm instead)
)
print("\n✅ Training complete!")
# Print best metrics
best_dice = max(history.history['val_dice_coefficient'])
best_iou = max(history.history['val_mean_iou'])
print(f"\n📊 Best validation metrics:")
print(f" Dice coefficient: {best_dice:.4f}")
print(f" Mean IoU: {best_iou:.4f}")
🚂 Starting training for 10 epochs...
 Backbone: RESNET50 (pretrained)
 Batch size: 8
 Learning rate: 0.0005
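The ReduceLROnPlateau settings used above (factor=0.5, patience=3) can be traced by hand. The pure-Python sketch below is a simplified model of that callback's patience logic for mode='max', assuming Keras's default cooldown=0 and min_delta=1e-4 as we understand them; `simulate_reduce_lr` is a hypothetical helper, and the score sequence is made up to show a plateau, not a training result:

```python
def simulate_reduce_lr(val_scores, lr=5e-4, factor=0.5, patience=3,
                       min_lr=1e-7, min_delta=1e-4):
    """Simplified trace of ReduceLROnPlateau(mode='max', cooldown=0).

    Returns the learning rate in effect after each epoch's check.
    """
    best = float('-inf')
    wait = 0
    lrs = []
    for score in val_scores:
        if score > best + min_delta:
            # Counts as an improvement: remember it and reset the patience counter
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                # Plateau: shrink the learning rate, but never below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs

# Made-up validation Dice values that rise from 0.70 to 0.72 and then plateau
scores = [0.70, 0.72, 0.72, 0.72, 0.72, 0.72, 0.72, 0.72, 0.72, 0.72]
for epoch, lr in enumerate(simulate_reduce_lr(scores), start=1):
    print(f"epoch {epoch:2d}: lr={lr:.2e}")
```

With patience=3, the earliest possible reductions in a 10-epoch run fall at epochs 4, 7 and 10, so at most three halvings occur and min_lr=1e-7 is never reached with the default configuration.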
📈 Training History
Visualize training metrics over epochs:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Loss
axes[0].plot(history.history['loss'], label='Train', linewidth=2)
axes[0].plot(history.history['val_loss'], label='Val', linewidth=2)
axes[0].set_title('Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Dice Coefficient
axes[1].plot(history.history['dice_coefficient'], label='Train', linewidth=2)
axes[1].plot(history.history['val_dice_coefficient'], label='Val', linewidth=2)
axes[1].set_title('Dice Coefficient')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Dice')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# Mean IoU
axes[2].plot(history.history['mean_iou'], label='Train', linewidth=2)
axes[2].plot(history.history['val_mean_iou'], label='Val', linewidth=2)
axes[2].set_title('Mean IoU')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('IoU')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.suptitle(f'Training History - {config.backbone.upper()} U-Net', fontsize=14)
plt.tight_layout()
plt.show()
# Print final metrics
print(f"\nFinal Training Metrics (Epoch {len(history.history['loss'])}):")
print(f" Train Dice: {history.history['dice_coefficient'][-1]:.4f}")
print(f" Val Dice: {history.history['val_dice_coefficient'][-1]:.4f}")
print(f" Train IoU: {history.history['mean_iou'][-1]:.4f}")
print(f" Val IoU: {history.history['val_mean_iou'][-1]:.4f}")
🎯 Model Evaluation
Evaluate model on test set:
# Evaluate on test set
print("🔍 Evaluating on test set...")
test_results = model.evaluate(
    test_dataset,
    verbose=0,
    callbacks=[TqdmCallback(verbose=1, leave=True)],
    return_dict=True  # Look up metrics by name instead of relying on list order
)
print("\n📊 Test Set Results:")
print(f" Loss: {test_results['loss']:.4f}")
print(f" Dice Coefficient: {test_results['dice_coefficient']:.4f}")
print(f" Mean IoU: {test_results['mean_iou']:.4f}")
print(f" Categorical Accuracy: {test_results['categorical_accuracy']:.4f}")
# Performance interpretation (rough rule-of-thumb thresholds on test Dice)
test_dice = test_results['dice_coefficient']
if test_dice > 0.8:
    print("\n✨ Excellent performance! The pretrained backbone works great!")
elif test_dice > 0.7:
    print("\n✅ Good performance! Transfer learning is effective!")
elif test_dice > 0.6:
    print("\n📈 Decent results. Consider training for more epochs.")
else:
    print("\n⚠️ Room for improvement. Try adjusting hyperparameters or augmentations.")
🔍 Visualize Predictions
Compare model predictions with ground truth:
def create_colored_mask(mask):
    """Map class indices to RGB colors."""
    colored = np.zeros((*mask.shape, 3))
    colored[mask == 0] = [0, 0, 0]  # Background: black
    colored[mask == 1] = [0, 1, 0]  # Pet: green
    colored[mask == 2] = [1, 0, 0]  # Border: red
    return colored

def visualize_predictions(model, dataset, num_samples=3):
    """Visualize model predictions vs ground truth."""
    fig, axes = plt.subplots(num_samples, 4, figsize=(15, num_samples * 3.5))
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    # Get predictions for one batch
    for images, masks_true in dataset.take(1):
        masks_pred = model.predict(images, verbose=0)
        for i in range(min(num_samples, images.shape[0])):
            # Denormalize image
            img = images[i].numpy()
            img = denormalize_image(img, config.backbone)
            # Convert masks to class indices
            mask_true = tf.argmax(masks_true[i], axis=-1).numpy()
            mask_pred = tf.argmax(masks_pred[i], axis=-1).numpy()
            colored_true = create_colored_mask(mask_true)
            colored_pred = create_colored_mask(mask_pred)
            # Pixel accuracy for this sample
            accuracy = np.mean(mask_true == mask_pred)
            # Plot
            axes[i, 0].imshow(img)
            axes[i, 0].set_title("Input Image")
            axes[i, 0].axis('off')
            axes[i, 1].imshow(colored_true)
            axes[i, 1].set_title("Ground Truth")
            axes[i, 1].axis('off')
            axes[i, 2].imshow(colored_pred)
            axes[i, 2].set_title(f"Prediction (Pixel acc: {accuracy:.2%})")
            axes[i, 2].axis('off')
            axes[i, 3].imshow(img)
            axes[i, 3].imshow(colored_pred, alpha=0.5)
            axes[i, 3].set_title("Overlay")
            axes[i, 3].axis('off')
    plt.suptitle(f'Model Predictions - {config.backbone.upper()} U-Net', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize predictions
print("Test set predictions:")
visualize_predictions(model, test_dataset, num_samples=3)
💾 Save Model
Save the trained model for future use:
# Save the model
model_path = os.path.join(config.checkpoint_dir, f'final_{config.backbone}_unet.keras')
model.save(model_path)
print(f"✅ Model saved to: {model_path}")
print(f" Size: {os.path.getsize(model_path) / (1024*1024):.2f} MB")
# To load the model later:
print("\nTo load this model later, use:")
print(f"model = keras.models.load_model('{model_path}', custom_objects={{")
print(" 'dice_coefficient': dice_coefficient,")
print(" 'combined_loss': combined_loss,")
print(" 'MeanIoU': MeanIoU")
print("})")
🎉 Conclusion
What We Achieved:
- ✅ Built U-Net with pretrained ImageNet encoder
- ✅ Applied segmentation-optimized augmentations with Albumentations
- ✅ Trained on Oxford-IIIT Pet dataset
- ✅ Evaluated segmentation quality on a held-out test split with Dice and Mean IoU
Key Takeaways:
- Pretrained encoders typically improve accuracy and speed up convergence compared with training from scratch
- Proper augmentations are crucial for segmentation (avoid aggressive distortions)
- Combined loss (Dice + CrossEntropy) helps with class imbalance while keeping gradients stable
- Albumentations provides powerful and efficient augmentation pipelines
Next Steps:
- Try different backbones ('resnet34', which builds ResNet50V2 here, plus MobileNetV2 and VGG16)
- Experiment with learning rates and batch sizes
- Train for more epochs for better performance
- Apply to your own segmentation dataset
Tips for Your Own Data:
- Adjust num_classes to match your dataset
- Modify augmentations based on your domain (medical, satellite, etc.)
- Use class weights if you have severe class imbalance
- Consider using larger input sizes for fine details (e.g. 384x384 or 512x512); keep each side divisible by 32 and update the crop size in create_augmentations to match
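For the class-weight tip, one common scheme is inverse pixel frequency, w_c = N / (C * n_c), where N is the total pixel count, C the number of classes and n_c the number of pixels of class c. A NumPy sketch under that assumption; `inverse_frequency_weights` and `weighted_cce` are hypothetical helpers, and wiring per-pixel weights into `combined_loss` is an adaptation you would make yourself:

```python
import numpy as np

def inverse_frequency_weights(masks, num_classes):
    # masks: integer class-index array of any shape, e.g. (N, H, W)
    counts = np.bincount(masks.ravel(), minlength=num_classes).astype(np.float64)
    # Rare classes get larger weights; a perfectly balanced dataset gets 1.0 everywhere.
    # Classes that never occur would divide by zero, so they get weight 0.
    weights = np.zeros(num_classes)
    present = counts > 0
    weights[present] = counts.sum() / (num_classes * counts[present])
    return weights

def weighted_cce(y_true, y_pred, class_weights, eps=1e-7):
    # y_true: one-hot (..., C); y_pred: softmax probabilities (..., C)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_pixel = -np.sum(y_true * np.log(y_pred), axis=-1)   # plain cross-entropy
    pixel_weights = np.sum(y_true * class_weights, axis=-1)  # weight of each pixel's true class
    return float(np.mean(per_pixel * pixel_weights))

# Toy batch: 5 background, 2 pet and 1 border pixel
masks = np.array([[[0, 0, 0, 1],
                   [0, 0, 1, 2]]])
weights = inverse_frequency_weights(masks, num_classes=3)
one_hot = np.eye(3)[masks]
uniform = np.full(one_hot.shape, 1.0 / 3.0)  # an uninformed prediction
print(weights)
print(weighted_cce(one_hot, uniform, weights))
```

Because these weights average to 1 over all pixels, a uniform prediction still scores ln(3), the same as unweighted cross-entropy; only the balance between classes shifts.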