Semantic Segmentation with Pretrained U-Net in Keras + Albumentations

🎯 Overview

This notebook demonstrates semantic segmentation using:

  • Encoders pretrained on ImageNet (ResNet50, MobileNetV2, VGG16, and ResNet50V2 as a stand-in for ResNet34)
  • U-Net decoder architecture
  • Albumentations for data augmentation applied jointly to images and masks
  • Automatic device detection (CUDA GPU / Apple MPS / CPU)
  • Oxford-IIIT Pet dataset (segmenting pets from background)

⚠️ Important for Apple Silicon (M1/M2/M3) Users:

To enable GPU acceleration via Metal Performance Shaders (MPS), you need:

pip install tensorflow-metal

Without tensorflow-metal, TensorFlow will use CPU only!

Key Features:

✅ Transfer learning from ImageNet weights
✅ Modern Keras 3.x implementation
✅ Segmentation-optimized augmentations
✅ Progress tracking with tqdm
✅ Expected Dice score: roughly 0.75-0.85 after 10-15 epochs (varies with backbone and random seed)

Why Pretrained Models?

  • Faster convergence: typically 2-3x faster than training from scratch
  • Better performance: Higher Dice scores with less data
  • Transfer learning: Leverages ImageNet features for segmentation

📦 Installation

# Install required packages
%pip install -q --upgrade albumentationsx tensorflow tensorflow-datasets keras matplotlib tqdm jupytext

Note: you may need to restart the kernel to use updated packages.

🔧 Imports and Setup

import os
import platform
import warnings
from typing import Tuple
from dataclasses import dataclass

# Silence TensorFlow C++ INFO/WARNING logs (must be set before TensorFlow is imported)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers, models, optimizers, callbacks
import albumentations as A
import matplotlib.pyplot as plt
from tqdm.keras import TqdmCallback
from tqdm import tqdm

# Display plots inline
%matplotlib inline

# Suppress Python warnings for cleaner output
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print("Versions:")
print(f"  TensorFlow: {tf.__version__}")
print(f"  Keras: {keras.__version__}")
print(f"  Albumentations: {A.__version__}")
print(f"  Platform: {platform.system()} {platform.processor()}")

Versions:
  TensorFlow: 2.19.1
  Keras: 3.11.3
  Albumentations: 2.0.11
  Platform: Darwin arm

🖥️ Automatic Device Detection

Automatically detects and configures the best available compute device:

def setup_device():
    """Setup and detect available compute device."""
    import platform
    system = platform.system()
    processor = platform.processor()
    
    # Check for GPU availability
    gpus = tf.config.list_physical_devices('GPU')
    
    if gpus:
        try:
            # Enable memory growth for GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            
            # Determine if it's CUDA or MPS
            if system == 'Darwin':
                print("✓ Apple Silicon MPS GPU detected")
                print("  TensorFlow will use Metal Performance Shaders")
                return 'MPS'
            else:
                device_name = f"CUDA GPU ({len(gpus)} device(s))"
                print(f"✓ Using {device_name}")
                try:
                    gpu_details = tf.config.experimental.get_device_details(gpus[0])
                    print(f"  GPU Name: {gpu_details.get('device_name', 'Unknown')}")
                except Exception:
                    pass
                return 'GPU'
        except RuntimeError as e:
            print(f"GPU setup failed: {e}")
    
    # Check if on Apple Silicon without GPU detected
    if system == 'Darwin' and processor == 'arm':
        print("⚠️ Apple Silicon detected but MPS not available!")
        print("  To enable GPU acceleration, install tensorflow-metal:")
        print("  pip install tensorflow-metal")
        print("  Currently using CPU - training will be slower")
        return 'CPU'
    
    # Regular CPU fallback
    print("ℹ Using CPU (no GPU detected)")
    print("  Training will be slower. Consider using Google Colab with GPU.")
    return 'CPU'

# Detect and setup device
device_type = setup_device()

✓ Apple Silicon MPS GPU detected
  TensorFlow will use Metal Performance Shaders

⚙️ Configuration

Configure training parameters and model settings:

@dataclass
class Config:
    """Configuration for training."""
    # Model
    backbone: str = 'resnet50'  # Options: 'resnet34' (loads ResNet50V2), 'resnet50', 'mobilenetv2', 'vgg16'
    input_shape: Tuple[int, int, int] = (256, 256, 3)
    num_classes: int = 3  # Background, pet, border
    
    # Training
    batch_size: int = 8  # Adjust based on your GPU memory
    epochs: int = 10  # Pretrained models converge faster
    learning_rate: float = 5e-4
    
    # Dataset splits
    train_split: str = "train[:80%]"
    val_split: str = "train[80%:90%]"
    test_split: str = "train[90%:]"
    
    # Paths
    checkpoint_dir: str = 'checkpoints'

# Create configuration
config = Config()

print("Training Configuration:")
print(f"  Backbone: {config.backbone.upper()} (pretrained on ImageNet)")
print(f"  Input shape: {config.input_shape}")
print(f"  Batch size: {config.batch_size}")
print(f"  Epochs: {config.epochs}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Classes: {config.num_classes} (background, pet, border)")

# Adjust batch size based on device
if device_type == 'CPU':
    config.batch_size = 4
    print(f"  → Reduced batch size to {config.batch_size} for CPU")

Training Configuration:
  Backbone: RESNET50 (pretrained on ImageNet)
  Input shape: (256, 256, 3)
  Batch size: 8
  Epochs: 10
  Learning rate: 0.0005
  Classes: 3 (background, pet, border)

🏗️ U-Net Model Architecture

Build U-Net with pretrained ImageNet encoder:

def build_unet_model(backbone_name: str, input_shape: Tuple, num_classes: int) -> keras.Model:
    """
    Build U-Net with pretrained backbone from ImageNet.
    
    The U-Net architecture consists of:
    1. Encoder: Pretrained backbone (ResNet, MobileNet, or VGG)
    2. Decoder: Transposed convolutions with skip connections
    3. Output: Softmax activation for multi-class segmentation
    """
    inputs = layers.Input(shape=input_shape)
    
    # Select pretrained encoder based on backbone choice
    if backbone_name == 'resnet34':
        # keras.applications has no ResNet34, so this option substitutes
        # ResNet50V2 (a deeper, pre-activation ResNet, not an equivalent model)
        print("Loading ResNet50V2 pretrained on ImageNet (substitute for ResNet34)...")
        encoder = tf.keras.applications.ResNet50V2(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['conv1_conv', 'conv2_block3_1_relu', 
                           'conv3_block4_1_relu', 'conv4_block6_1_relu']
        
    elif backbone_name == 'resnet50':
        print("Loading ResNet50 pretrained on ImageNet...")
        encoder = tf.keras.applications.ResNet50(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['conv1_relu', 'conv2_block3_out', 
                           'conv3_block4_out', 'conv4_block6_out']
        
    elif backbone_name == 'mobilenetv2':
        print("Loading MobileNetV2 pretrained on ImageNet...")
        encoder = tf.keras.applications.MobileNetV2(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['block_1_expand_relu', 'block_3_expand_relu',
                           'block_6_expand_relu', 'block_13_expand_relu']
        
    elif backbone_name == 'vgg16':
        print("Loading VGG16 pretrained on ImageNet...")
        encoder = tf.keras.applications.VGG16(
            input_tensor=inputs,
            weights='imagenet',
            include_top=False
        )
        skip_layer_names = ['block1_pool', 'block2_pool', 
                           'block3_pool', 'block4_pool']
    else:
        raise ValueError(f"Unknown backbone: {backbone_name}")
    
    # Make encoder trainable for fine-tuning
    encoder.trainable = True
    
    # Get skip connections from encoder layers
    skip_layers = [encoder.get_layer(name).output for name in skip_layer_names]
    
    # Build U-Net decoder
    x = encoder.output
    decoder_filters = [256, 128, 64, 32]
    
    # Decoder blocks with skip connections
    for i, (skip, filters) in enumerate(zip(reversed(skip_layers), decoder_filters)):
        # Upsampling
        x = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        
        # Skip connection (concatenate)
        x = layers.Concatenate()([x, skip])
        
        # Double convolution block
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        
        # Dropout for regularization (except last block)
        if filters > 32:
            x = layers.Dropout(0.3)(x)
    
    # Final upsampling to match input resolution
    x = layers.Conv2DTranspose(16, 3, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    # Output layer with softmax for multi-class segmentation
    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)
    
    model = models.Model(inputs=inputs, outputs=outputs, name=f'{backbone_name}_unet')
    print(f"Model created: {model.count_params():,} parameters")
    
    return model
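The decoder only works if each stride-2 Conv2DTranspose lands exactly on the resolution of the next skip tensor. Judging by the layer names, each backbone branch taps four skip layers at strides 2, 4, 8 and 16, with the encoder output at stride 32. The sketch below traces those spatial sizes in plain Python under that assumption. `unet_spatial_sizes` is a hypothetical helper and is not part of the notebook.

```python
def unet_spatial_sizes(input_size: int, num_downsamples: int = 5):
    """Trace spatial sizes through the U-Net, assuming a total encoder stride of 32."""
    total_stride = 2 ** num_downsamples
    if input_size % total_stride != 0:
        raise ValueError(f"input size {input_size} must be divisible by {total_stride}")
    # Encoder: resolutions of the four skip tensors (strides 2, 4, 8, 16)
    skips = [input_size // 2 ** k for k in range(1, num_downsamples)]
    bottleneck = input_size // total_stride
    # Decoder: each Conv2DTranspose(strides=2, padding='same') doubles the size
    decoder = []
    size = bottleneck
    for skip in reversed(skips):
        size *= 2
        assert size == skip, "upsampled tensor must match its skip tensor"
        decoder.append(size)
    final = size * 2  # last Conv2DTranspose restores the input resolution
    return skips, bottleneck, decoder, final


skips, bottleneck, decoder, final = unet_spatial_sizes(256)
print("skips:", skips)            # skips: [128, 64, 32, 16]
print("bottleneck:", bottleneck)  # bottleneck: 8
print("decoder:", decoder)        # decoder: [16, 32, 64, 128]
print("output:", final)           # output: 256
```

The same check explains why larger inputs such as 384x384 or 512x512 fit this architecture, while a size like 250 would break the skip concatenations.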

🎨 Albumentations Augmentation Pipeline

Create segmentation-optimized augmentations with backbone-specific normalization:

def create_augmentations(training: bool = True, backbone_name: str = 'resnet50'):
    """Create Albumentations pipeline optimized for segmentation.
    
    Key principles for segmentation augmentations:
    - Avoid aggressive geometric distortions (no ElasticTransform, GridDistortion)
    - Use conservative spatial transforms (mild rotation, scaling)
    - Apply color augmentations for robustness
    - Include normalization as the last step
    """
    
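    # Get normalization parameters for each backbone
    if backbone_name in ['resnet34', 'resnet50', 'vgg16']:
        # Standard (torch-style) ImageNet mean/std normalization.
        # Note: the keras.applications ResNet50 and VGG16 weights expect 'caffe'
        # preprocessing (BGR, mean subtraction, no std scaling), and ResNet50V2
        # expects [-1, 1] scaling, so these statistics only approximate the
        # pretraining inputs; the encoder is fine-tuned end to end.
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)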
    elif backbone_name == 'mobilenetv2':
        # MobileNetV2 normalizes to [-1, 1]
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    else:
        # Default ImageNet normalization
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    
    if training:
        return A.Compose([
            # Spatial augmentations (conservative for segmentation)
            A.RandomCrop(height=256, width=256, pad_if_needed=True),
            A.HorizontalFlip(p=0.5),
            A.Affine(
                scale=(0.5, 2),  # Zoom between 0.5x (out) and 2x (in)
                rotate=(-15, 15),  # Mild rotation
                balanced_scale=True,  # Zooming in and zooming out equally likely
                keep_ratio=True,
                p=0.5
            ),
            
            # Color augmentations
            A.ColorJitter(
                brightness=[0.8, 1.2],
                contrast=[0.8, 1.2],
                saturation=[0.8, 1.2],
                hue=[-0.5, 0.5],  # Widest allowed range: any hue shift
                p=1
            ),  # Random brightness, contrast, saturation, hue
            
            # Noise and blur (mild)
            A.OneOf([
                A.GaussNoise(std_range=(0.1, 0.2), p=1),
                A.GaussianBlur(blur_limit=(3, 5), p=1),
            ], p=0.2),
            
            # Normalization (must be last!)
            A.Normalize(mean=mean, std=std),
        ], seed=137, strict=True)  # seed for reproducibility
    else:
        return A.Compose([
            # For validation/test: only center-crop (padding small images) and normalize
            A.CenterCrop(height=256, width=256, pad_if_needed=True, p=1),
            A.Normalize(mean=mean, std=std),
        ])

# Preview augmentation settings
print("Augmentation Pipeline:")
print("  Training:")
print("    - RandomCrop with padding")
print("    - HorizontalFlip (50%)")
print("    - Affine transforms (scale, rotate)")
print("    - ColorJitter")
print("    - Noise/Blur (20%)")
print("    - Normalization")
print("  Validation:")
print("    - CenterCrop with padding")
print("    - Normalization")

Augmentation Pipeline:
  Training:
    - RandomCrop with padding
    - HorizontalFlip (50%)
    - Affine transforms (scale, rotate)
    - ColorJitter
    - Noise/Blur (20%)
    - Normalization
  Validation:
    - CenterCrop with padding
    - Normalization
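With its default max_pixel_value=255, A.Normalize computes (x / 255 - mean) / std per channel, and denormalize_image later inverts this for plotting. The NumPy sketch below checks that round trip. It also confirms that mean = std = 0.5 reproduces the [-1, 1] scaling (x / 127.5 - 1) that keras.applications MobileNetV2 expects. `normalize` and `denormalize` are hypothetical names that only mirror the arithmetic.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize(img_uint8, mean, std, max_pixel_value=255.0):
    """Mirror of A.Normalize: (img / max_pixel_value - mean) / std, per channel."""
    return (img_uint8.astype(np.float32) / max_pixel_value - mean) / std


def denormalize(img_norm, mean, std):
    """Inverse used for plotting: img_norm * std + mean, clipped to [0, 1]."""
    return np.clip(img_norm * std + mean, 0.0, 1.0)


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 4, 3)).astype(np.uint8)

# Round trip recovers the original pixels scaled to [0, 1]
restored = denormalize(normalize(img, IMAGENET_MEAN, IMAGENET_STD),
                       IMAGENET_MEAN, IMAGENET_STD)
print(np.allclose(restored, img / 255.0, atol=1e-5))  # True

# mean = std = 0.5 maps [0, 255] onto [-1, 1], i.e. x / 127.5 - 1
half = np.array([0.5, 0.5, 0.5])
print(np.allclose(normalize(img, half, half), img / 127.5 - 1.0, atol=1e-5))  # True
```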

📊 Loss Functions and Metrics

Define segmentation-specific loss functions and metrics:

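def dice_coefficient(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient for one-hot masks, averaged over classes.
    
    Per class: Dice = (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth),
    summed over every pixel in the batch using predicted probabilities.
    Range: [0, 1] where 1 is perfect overlap
    """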
    y_true_f = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
    y_pred_f = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    union = tf.reduce_sum(y_true_f, axis=0) + tf.reduce_sum(y_pred_f, axis=0)
    dice = tf.reduce_mean((2. * intersection + smooth) / (union + smooth))
    return dice

def dice_loss(y_true, y_pred):
    """Dice loss for training."""
    return 1.0 - dice_coefficient(y_true, y_pred)

def combined_loss(y_true, y_pred):
    """Combined dice + categorical crossentropy loss.
    
    This combination helps with:
    - Dice: Handles class imbalance
    - CrossEntropy: Provides stable gradients
    """
    cce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return dice_loss(y_true, y_pred) + tf.reduce_mean(cce)

class MeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU metric for one-hot encoded masks."""
    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert from one-hot to class indices
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)

print("Loss and Metrics:")
print("  Loss: Combined (Dice + Categorical Crossentropy)")
print("  Metrics: Dice Coefficient, Mean IoU, Categorical Accuracy")

Loss and Metrics:
  Loss: Combined (Dice + Categorical Crossentropy)
  Metrics: Dice Coefficient, Mean IoU, Categorical Accuracy
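dice_coefficient computes a soft Dice score. It uses predicted probabilities, flattens the batch and spatial axes together, and averages the per-class scores. The NumPy mirror below runs it on a 2x2 toy mask so the effect of the smoothing term is visible. `soft_dice` is a hypothetical name; the math follows the TensorFlow function above.

```python
import numpy as np


def soft_dice(y_true, y_pred, smooth=1.0):
    """NumPy mirror of dice_coefficient: per-class soft Dice, averaged over classes."""
    num_classes = y_true.shape[-1]
    t = y_true.reshape(-1, num_classes)  # all pixels x classes
    p = y_pred.reshape(-1, num_classes)
    intersection = (t * p).sum(axis=0)
    union = t.sum(axis=0) + p.sum(axis=0)
    return float(np.mean((2.0 * intersection + smooth) / (union + smooth)))


# Toy 2x2 mask with 3 classes: labels [[0, 1], [1, 2]]
labels = np.array([[0, 1], [1, 2]])
y_true = np.eye(3)[labels]  # one-hot, shape (2, 2, 3)

print(soft_dice(y_true, y_true))  # 1.0 for a perfect prediction

# Predict class 1 (pet) everywhere
y_pred = np.eye(3)[np.ones_like(labels)]
print(round(soft_dice(y_true, y_pred), 4))  # 0.5714
```

With smooth=1.0, the two classes this prediction misses entirely still score 0.5 each on the 4-pixel toy. On a real batch of 256x256 masks, the smoothing term is negligible next to the pixel counts.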

📁 Dataset Loading

Load Oxford-IIIT Pet dataset for semantic segmentation:

def load_dataset(config: Config):
    """Load Oxford-IIIT Pet dataset.
    
    The dataset contains:
    - Images of pets (cats and dogs)
    - Segmentation trimaps with 3 classes, stored as
      1: Pet, 2: Background, 3: Border
      (preprocess_data remaps them to 0: Background, 1: Pet, 2: Border)
    """
    print("Loading Oxford-IIIT Pet dataset...")
    print("This may take a few minutes on first run...")
    
    with tqdm(total=3, desc="Loading splits", unit="split") as pbar:
        (ds_train, ds_val, ds_test), ds_info = tfds.load(
            'oxford_iiit_pet',
            split=[config.train_split, config.val_split, config.test_split],
            with_info=True
        )
        pbar.update(3)
    
    print(f"\n✓ Dataset loaded successfully")
    print(f"  Train samples: {len(ds_train)}")
    print(f"  Val samples: {len(ds_val)}")
    print(f"  Test samples: {len(ds_test)}")
    
    return ds_train, ds_val, ds_test, ds_info

# Load the dataset
ds_train, ds_val, ds_test, ds_info = load_dataset(config)

Loading Oxford-IIIT Pet dataset...
This may take a few minutes on first run...
Loading splits: 100%|██████████| 3/3 [00:00<00:00, 41.73split/s]

✓ Dataset loaded successfully
  Train samples: 2944
  Val samples: 368
  Test samples: 368
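All three subsets are percentage slices of the dataset's train split. The "test" set here is therefore a held-out 10% of train, not the dataset's own test split. The printed counts follow directly from the slice boundaries. The check below assumes the train split holds 3,680 examples (the sum of the three counts above) and that boundaries round to the nearest example. `slice_size` is a hypothetical helper.

```python
NUM_TRAIN = 3680  # assumption: 2944 + 368 + 368, as printed above


def slice_size(n, start_pct, end_pct):
    """Number of examples in train[start_pct%:end_pct%], boundaries rounded."""
    return round(n * end_pct / 100) - round(n * start_pct / 100)


sizes = {
    "train": slice_size(NUM_TRAIN, 0, 80),   # train[:80%]
    "val": slice_size(NUM_TRAIN, 80, 90),    # train[80%:90%]
    "test": slice_size(NUM_TRAIN, 90, 100),  # train[90%:]
}
print(sizes)  # {'train': 2944, 'val': 368, 'test': 368}
```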

🔄 Data Pipeline

Create efficient tf.data pipeline with augmentations:

def preprocess_data(data_dict):
    """Preprocess raw data from TensorFlow datasets."""
    image = tf.cast(data_dict['image'], tf.float32)
    mask = data_dict['segmentation_mask']
    
    # Convert mask to int32 and ensure 2D
    mask = tf.cast(mask, tf.int32)
    mask = tf.squeeze(mask)  # Remove any extra dimensions
    
    # Remap trimap values to 0-based class ids
    # Original trimap: 1=pet, 2=background, 3=border (not classified)
    # New: 0=background, 1=pet, 2=border
    mask = tf.where(mask == 2, 0, mask)  # background -> 0
    mask = tf.where(mask == 3, 2, mask)  # border -> 2
    
    return image, mask

def create_data_pipeline(dataset, config: Config, augmentations, training: bool = True):
    """Create tf.data pipeline with augmentations."""
    
    def augment_fn(image, mask):
        """Apply Albumentations augmentations."""
        def aug(img, msk):
            img = img.numpy().astype(np.uint8)
            msk = msk.numpy().astype(np.uint8)
            
            # Apply augmentations (including normalization)
            augmented = augmentations(image=img, mask=msk)
            
            return augmented['image'], augmented['mask'].astype(np.int32)
        
        # Use tf.py_function to apply numpy-based augmentations
        aug_img, aug_mask = tf.py_function(
            aug, [image, mask], [tf.float32, tf.int32]
        )
        
        # Set shapes (required after py_function)
        aug_img.set_shape([config.input_shape[0], config.input_shape[1], 3])
        aug_mask.set_shape([config.input_shape[0], config.input_shape[1]])
        
        # One-hot encode the mask for multi-class segmentation
        aug_mask = tf.one_hot(aug_mask, config.num_classes)
        
        return aug_img, aug_mask
    
    # Build pipeline
    dataset = dataset.map(preprocess_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    
    if training:
        dataset = dataset.shuffle(buffer_size=1000)
    
    dataset = dataset.batch(config.batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset

# Create augmentation pipelines
train_aug = create_augmentations(training=True, backbone_name=config.backbone)
val_aug = create_augmentations(training=False, backbone_name=config.backbone)

# Create data pipelines
print("\nCreating data pipelines...")
train_dataset = create_data_pipeline(ds_train, config, train_aug, training=True)
val_dataset = create_data_pipeline(ds_val, config, val_aug, training=False)
test_dataset = create_data_pipeline(ds_test, config, val_aug, training=False)
print("✓ Data pipelines ready")

Creating data pipelines...
✓ Data pipelines ready
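The two tf.where calls in preprocess_data amount to a lookup table from trimap values to class ids. Below is a NumPy sketch of the same remapping, followed by the one-hot step that tf.one_hot performs in augment_fn. `REMAP` and `remap_and_one_hot` are hypothetical names.

```python
import numpy as np

# Oxford-IIIT Pet trimap values: 1 = pet, 2 = background, 3 = border.
# Lookup table indexed by raw value (index 0 is unused).
REMAP = np.array([0, 1, 0, 2])  # 1 -> 1 pet, 2 -> 0 background, 3 -> 2 border


def remap_and_one_hot(trimap, num_classes=3):
    """Return the remapped class-id mask and its one-hot encoding."""
    mask = REMAP[trimap.astype(np.int64)]
    return mask, np.eye(num_classes, dtype=np.float32)[mask]


trimap = np.array([[2, 2, 3],
                   [2, 3, 1],
                   [3, 1, 1]])
mask, one_hot = remap_and_one_hot(trimap)
print(mask)
# [[0 0 2]
#  [0 2 1]
#  [2 1 1]]
print(one_hot.shape)  # (3, 3, 3)
```

Inside the tf.data pipeline, the same table could be applied with `tf.gather(tf.constant([0, 1, 0, 2]), mask)` instead of chained tf.where calls.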

👀 Visualize Data Samples

Preview augmented training samples:

def denormalize_image(img, backbone_name='resnet50'):
    """Properly denormalize image for visualization."""
    # Get normalization parameters
    if backbone_name in ['resnet34', 'resnet50', 'vgg16']:
        # Standard ImageNet normalization
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
    elif backbone_name == 'mobilenetv2':
        # MobileNetV2 normalizes to [-1, 1]
        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.5, 0.5, 0.5])
    else:
        # Default ImageNet normalization
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
    
    # Denormalize: x = (x_norm * std) + mean
    img_denorm = (img * std) + mean
    # Clip to [0, 1] range for valid image
    img_denorm = np.clip(img_denorm, 0, 1)
    return img_denorm

def visualize_samples(dataset, num_samples=3):
    """Visualize images and masks from dataset."""
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 3))
    
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    
    # Get a batch of data
    for images, masks in dataset.take(1):
        for i in range(min(num_samples, images.shape[0])):
            # Properly denormalize image for visualization
            img = images[i].numpy()
            img = denormalize_image(img, config.backbone)
            
            # Convert one-hot mask to class indices
            mask = tf.argmax(masks[i], axis=-1).numpy()
            
            # Create colored mask for better visualization
            colored_mask = np.zeros((*mask.shape, 3))
            colored_mask[mask == 0] = [0, 0, 0]     # Background: black
            colored_mask[mask == 1] = [0, 1, 0]     # Pet: green
            colored_mask[mask == 2] = [1, 0, 0]     # Border: red
            
            # Plot image
            axes[i, 0].imshow(img)
            axes[i, 0].set_title("Input Image")
            axes[i, 0].axis('off')
            
            # Plot mask
            axes[i, 1].imshow(colored_mask)
            axes[i, 1].set_title("Segmentation Mask")
            axes[i, 1].axis('off')
            
            # Plot overlay
            axes[i, 2].imshow(img)
            axes[i, 2].imshow(colored_mask, alpha=0.5)
            axes[i, 2].set_title("Overlay")
            axes[i, 2].axis('off')
    
    plt.suptitle("Training Samples with Albumentations Augmentations", fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize training samples
print("Training samples with augmentations:")
visualize_samples(train_dataset, num_samples=3)

Training samples with augmentations:

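[Figure: Training Samples with Albumentations Augmentations; columns: Input Image, Segmentation Mask, Overlay]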
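Masks store class ids, not intensities. Spatial transforms must therefore resample them with nearest-neighbor interpolation, because blending labels can invent classes. Albumentations' spatial transforms resample masks passed via mask= with nearest-neighbor interpolation by default. The sketch below upsamples a one-dimensional label row both ways to show what goes wrong otherwise. The helper names are hypothetical.

```python
import numpy as np


def resize_row_linear(row, new_len):
    """Linear interpolation of a 1-D label row (what bilinear resizing does per axis)."""
    x_old = np.arange(len(row))
    x_new = np.linspace(0, len(row) - 1, new_len)
    return np.interp(x_new, x_old, row)


def resize_row_nearest(row, new_len):
    """Nearest-neighbor resampling: every output value is an existing label."""
    idx = np.rint(np.linspace(0, len(row) - 1, new_len)).astype(int)
    return row[idx]


row = np.array([0, 2])  # background next to border
print(resize_row_linear(row, 3))   # [0. 1. 2.]  -> a spurious "pet" (1) pixel
print(resize_row_nearest(row, 3))  # [0 0 2]
```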

🚀 Build and Compile Model

# Build the model
print(f"Building {config.backbone.upper()} U-Net...\n")
model = build_unet_model(config.backbone, config.input_shape, config.num_classes)

# Compile the model
model.compile(
    optimizer=optimizers.Adam(learning_rate=config.learning_rate),
    loss=combined_loss,
    metrics=[
        dice_coefficient,
        MeanIoU(num_classes=config.num_classes, name='mean_iou'),
        'categorical_accuracy'
    ]
)

print("\n📊 Model Summary:")
print(f"  Architecture: U-Net with {config.backbone.upper()} encoder")
print(f"  Total parameters: {model.count_params():,}")
print(f"  Loss: Combined (Dice + Categorical Crossentropy)")
print(f"  Optimizer: Adam (lr={config.learning_rate})")

Building RESNET50 U-Net...

Loading ResNet50 pretrained on ImageNet...
Model created: 33,387,043 parameters

📊 Model Summary:
  Architecture: U-Net with RESNET50 encoder
  Total parameters: 33,387,043
  Loss: Combined (Dice + Categorical Crossentropy)
  Optimizer: Adam (lr=0.0005)

🏋️ Training

Train the model with callbacks for monitoring and early stopping:

# Setup callbacks
os.makedirs(config.checkpoint_dir, exist_ok=True)

checkpoint = callbacks.ModelCheckpoint(
    filepath=os.path.join(config.checkpoint_dir, f'best_{config.backbone}_unet.keras'),
    monitor='val_dice_coefficient',
    mode='max',
    save_best_only=True,
    verbose=0
)

early_stop = callbacks.EarlyStopping(
    monitor='val_dice_coefficient',
    mode='max',
    patience=7,
    restore_best_weights=True,
    verbose=0
)

reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_dice_coefficient',
    mode='max',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=0
)

# Create tqdm callback for progress bars
tqdm_callback = TqdmCallback(verbose=2, leave=True)

# Train the model
print(f"\n🚂 Starting training for {config.epochs} epochs...")
print(f"   Backbone: {config.backbone.upper()} (pretrained)")
print(f"   Batch size: {config.batch_size}")
print(f"   Learning rate: {config.learning_rate}\n")

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=config.epochs,
    callbacks=[checkpoint, early_stop, reduce_lr, tqdm_callback],
    verbose=0  # Disable default progress bar (using tqdm instead)
)

print("\n✅ Training complete!")

# Print best metrics
best_dice = max(history.history['val_dice_coefficient'])
best_iou = max(history.history['val_mean_iou'])
print(f"\n📊 Best validation metrics:")
print(f"   Dice coefficient: {best_dice:.4f}")
print(f"   Mean IoU: {best_iou:.4f}")

 

🚂 Starting training for 10 epochs...
   Backbone: RESNET50 (pretrained)
   Batch size: 8
   Learning rate: 0.0005
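With epochs=10, the two patience values interact. ReduceLROnPlateau halves the learning rate after three epochs without improvement. EarlyStopping needs seven, so it can only fire if the best validation Dice occurs in the first three epochs.

Below is a simplified pure-Python model of both rules, not the Keras source. `simulate_schedule` is hypothetical, and it makes three assumptions:
- ReduceLROnPlateau uses the Keras default min_delta=1e-4.
- EarlyStopping uses the Keras default min_delta=0.
- Cooldown is ignored.

```python
def simulate_schedule(val_scores, lr=5e-4, factor=0.5, lr_patience=3,
                      stop_patience=7, min_lr=1e-7, lr_min_delta=1e-4):
    """Simplified ReduceLROnPlateau + EarlyStopping in 'max' mode (no cooldown).

    Returns the learning rate used in each epoch and the number of epochs run.
    """
    best_lr_metric = best_stop_metric = float("-inf")
    lr_wait = stop_wait = 0
    lrs = []
    for epoch, score in enumerate(val_scores, start=1):
        lrs.append(lr)  # learning rate in effect during this epoch
        # ReduceLROnPlateau: an improvement must exceed min_delta
        if score > best_lr_metric + lr_min_delta:
            best_lr_metric, lr_wait = score, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr, lr_wait = max(lr * factor, min_lr), 0
        # EarlyStopping: any strict improvement resets the counter
        if score > best_stop_metric:
            best_stop_metric, stop_wait = score, 0
        else:
            stop_wait += 1
            if stop_wait >= stop_patience:
                return lrs, epoch
    return lrs, len(val_scores)


# Best Dice at epoch 4, then a plateau
lrs, epochs_run = simulate_schedule([0.60, 0.70, 0.75, 0.78] + [0.78] * 6)
print(epochs_run)  # 10: EarlyStopping never fires
print(lrs)         # 5e-4 for epochs 1-7, then 2.5e-4

# Best Dice at epoch 2: EarlyStopping fires after seven flat epochs
print(simulate_schedule([0.70, 0.72] + [0.72] * 8)[1])  # 9
```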



📈 Training History

Visualize training metrics over epochs:

# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(history.history['loss'], label='Train', linewidth=2)
axes[0].plot(history.history['val_loss'], label='Val', linewidth=2)
axes[0].set_title('Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Dice Coefficient
axes[1].plot(history.history['dice_coefficient'], label='Train', linewidth=2)
axes[1].plot(history.history['val_dice_coefficient'], label='Val', linewidth=2)
axes[1].set_title('Dice Coefficient')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Dice')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Mean IoU
axes[2].plot(history.history['mean_iou'], label='Train', linewidth=2)
axes[2].plot(history.history['val_mean_iou'], label='Val', linewidth=2)
axes[2].set_title('Mean IoU')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('IoU')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle(f'Training History - {config.backbone.upper()} U-Net', fontsize=14)
plt.tight_layout()
plt.show()

# Print final metrics
print(f"\nFinal Training Metrics (Epoch {len(history.history['loss'])}):")
print(f"  Train Dice: {history.history['dice_coefficient'][-1]:.4f}")
print(f"  Val Dice: {history.history['val_dice_coefficient'][-1]:.4f}")
print(f"  Train IoU: {history.history['mean_iou'][-1]:.4f}")
print(f"  Val IoU: {history.history['val_mean_iou'][-1]:.4f}")

🎯 Model Evaluation

Evaluate model on test set:

# Evaluate on test set
print("🔍 Evaluating on test set...")

test_results = model.evaluate(
    test_dataset,
    verbose=0,
    callbacks=[TqdmCallback(verbose=1, leave=True)],
    return_dict=True  # Look up results by metric name instead of position
)

print("\n📊 Test Set Results:")
print(f"   Loss: {test_results['loss']:.4f}")
print(f"   Dice Coefficient: {test_results['dice_coefficient']:.4f}")
print(f"   Mean IoU: {test_results['mean_iou']:.4f}")
print(f"   Categorical Accuracy: {test_results['categorical_accuracy']:.4f}")

# Performance interpretation
test_dice = test_results['dice_coefficient']
if test_dice > 0.8:
    print("\n✨ Excellent performance! The pretrained backbone works great!")
elif test_dice > 0.7:
    print("\n✅ Good performance! Transfer learning is effective!")
elif test_dice > 0.6:
    print("\n📈 Decent results. Consider training for more epochs.")
else:
    print("\n⚠️ Room for improvement. Try adjusting hyperparameters or augmentations.")
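The two headline metrics are computed differently, so they are not directly interchangeable.
- The Dice score is a soft, per-batch value computed from probabilities and averaged over batches.
- MeanIoU accumulates a confusion matrix over argmax predictions for the whole test set.

The NumPy sketch below shows a MeanIoU-style computation from a confusion matrix. It then shows the per-class identity D = 2J / (1 + J) that links hard Dice and IoU. The function names are hypothetical; the sketch skips classes absent from both maps, mirroring how Keras' MeanIoU averages only over classes with a nonzero denominator.

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes):
    """Per-class TP, FP, FN from integer label maps (rows = truth, cols = prediction)."""
    cm = np.bincount(num_classes * y_true.ravel() + y_pred.ravel(),
                     minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # pixels predicted as class c that belong elsewhere
    fn = cm.sum(axis=1) - tp  # pixels of class c predicted as something else
    return tp, fp, fn


def mean_iou(y_true, y_pred, num_classes):
    tp, fp, fn = confusion_counts(y_true, y_pred, num_classes)
    denom = tp + fp + fn
    valid = denom > 0  # skip classes absent from both maps
    return float(np.mean(tp[valid] / denom[valid]))


labels = np.array([[0, 1], [1, 2]])
preds = np.ones_like(labels)  # predict "pet" everywhere
print(round(mean_iou(labels, preds, 3), 4))  # 0.1667

# Per class, hard Dice and IoU satisfy D = 2J / (1 + J)
tp, fp, fn = confusion_counts(labels, preds, 3)
iou = tp / (tp + fp + fn)
dice = 2 * tp / (2 * tp + fp + fn)
print(np.allclose(dice, 2 * iou / (1 + iou)))  # True
```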

🔍 Visualize Predictions

Compare model predictions with ground truth:

def visualize_predictions(model, dataset, num_samples=3):
    """Visualize model predictions vs ground truth."""
    fig, axes = plt.subplots(num_samples, 4, figsize=(15, num_samples * 3.5))
    
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    
    # Get predictions
    for images, masks_true in dataset.take(1):
        masks_pred = model.predict(images, verbose=0)
        
        for i in range(min(num_samples, images.shape[0])):
            # Denormalize image
            img = images[i].numpy()
            img = denormalize_image(img, config.backbone)
            
            # Convert masks to class indices
            mask_true = tf.argmax(masks_true[i], axis=-1).numpy()
            mask_pred = tf.argmax(masks_pred[i], axis=-1).numpy()
            
            # Create colored masks
            def create_colored_mask(mask):
                colored = np.zeros((*mask.shape, 3))
                colored[mask == 0] = [0, 0, 0]  # Background: black
                colored[mask == 1] = [0, 1, 0]  # Pet: green
                colored[mask == 2] = [1, 0, 0]  # Border: red
                return colored
            
            colored_true = create_colored_mask(mask_true)
            colored_pred = create_colored_mask(mask_pred)
            
            # Calculate accuracy for this sample
            accuracy = np.mean(mask_true == mask_pred)
            
            # Plot
            axes[i, 0].imshow(img)
            axes[i, 0].set_title("Input Image")
            axes[i, 0].axis('off')
            
            axes[i, 1].imshow(colored_true)
            axes[i, 1].set_title("Ground Truth")
            axes[i, 1].axis('off')
            
            axes[i, 2].imshow(colored_pred)
            axes[i, 2].set_title(f"Prediction (Acc: {accuracy:.2%})")
            axes[i, 2].axis('off')
            
            axes[i, 3].imshow(img)
            axes[i, 3].imshow(colored_pred, alpha=0.5)
            axes[i, 3].set_title("Overlay")
            axes[i, 3].axis('off')
    
    plt.suptitle(f'Model Predictions - {config.backbone.upper()} U-Net', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize predictions
print("Test set predictions:")
visualize_predictions(model, test_dataset, num_samples=3)
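visualize_samples and visualize_predictions each build colored masks with three boolean assignments. NumPy fancy indexing into a palette does the same in a single lookup and keeps the color table in one place. `PALETTE` and `colorize` are hypothetical names; the colors match the ones used above.

```python
import numpy as np

# Palette indexed by class id: background black, pet green, border red
PALETTE = np.array([[0, 0, 0],
                    [0, 1, 0],
                    [1, 0, 0]], dtype=np.float32)


def colorize(mask):
    """Map an (H, W) class-index mask to an (H, W, 3) RGB image in one lookup."""
    return PALETTE[mask]


mask = np.array([[0, 1], [2, 1]])
rgb = colorize(mask)
print(rgb.shape)  # (2, 2, 3)
print(rgb[1, 0])  # [1. 0. 0.]
```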

💾 Save Model

Save the trained model for future use:

# Save the model
model_path = os.path.join(config.checkpoint_dir, f'final_{config.backbone}_unet.keras')
model.save(model_path)
print(f"✅ Model saved to: {model_path}")
print(f"   Size: {os.path.getsize(model_path) / (1024*1024):.2f} MB")

# To load the model later:
print("\nTo load this model later, use:")
print(f"model = keras.models.load_model('{model_path}', custom_objects={{")
print("    'dice_coefficient': dice_coefficient,")
print("    'combined_loss': combined_loss,")
print("    'MeanIoU': MeanIoU")
print("})")

🎉 Conclusion

What We Achieved:

  • ✅ Built U-Net with pretrained ImageNet encoder
  • ✅ Applied segmentation-optimized augmentations with Albumentations
  • ✅ Trained on Oxford-IIIT Pet dataset
  • ✅ Evaluated segmentation quality with Dice and mean IoU on a held-out test split

Key Takeaways:

  1. Pretrained models significantly improve performance and convergence speed
  2. Proper augmentations are crucial for segmentation (avoid aggressive distortions)
  3. Combined loss pairs Dice, which is less sensitive to class imbalance, with cross-entropy's stable gradients
  4. Albumentations provides powerful and efficient augmentation pipelines

Next Steps:

  • Try the other backbones: MobileNetV2, VGG16, or the 'resnet34' option (which loads ResNet50V2)
  • Experiment with learning rates and batch sizes
  • Train for more epochs for better performance
  • Apply to your own segmentation dataset

Tips for Your Own Data:

  1. Adjust num_classes to match your dataset
  2. Modify augmentations based on your domain (medical, satellite, etc.)
  3. Use class weights if you have severe class imbalance
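  4. Consider using larger input sizes for fine details (384x384, 512x512); keep each side divisible by 32 and update the 256x256 crop sizes hardcoded in create_augmentations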
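Tip 3 can be made concrete with inverse-frequency weighting. Each class is weighted by the reciprocal of its pixel frequency, so rare classes such as the border count more in the cross-entropy term. Below is a NumPy sketch. The function names are hypothetical, and inverse frequency is only one of several common weighting schemes.

```python
import numpy as np


def inverse_frequency_weights(masks, num_classes):
    """Per-class weights proportional to 1 / pixel frequency, normalized to mean 1."""
    counts = np.bincount(masks.ravel(), minlength=num_classes).astype(np.float64)
    freq = counts / counts.sum()
    weights = np.where(freq > 0, 1.0 / np.maximum(freq, 1e-12), 0.0)
    return weights / weights[weights > 0].mean()


def weighted_cce(y_true_onehot, y_pred, class_weights, eps=1e-7):
    """Pixel-wise categorical cross-entropy scaled by each pixel's true-class weight."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_pixel = -np.sum(y_true_onehot * np.log(y_pred), axis=-1)
    pixel_weights = np.sum(y_true_onehot * class_weights, axis=-1)
    return float(np.mean(per_pixel * pixel_weights))


# Toy batch of one 2x4 mask: background dominates, border is rare
masks = np.array([[[0, 0, 0, 0], [0, 1, 1, 2]]])
w = inverse_frequency_weights(masks, 3)
print(np.round(w, 3))  # [0.353 0.882 1.765]

onehot = np.eye(3)[masks]
uniform = np.full(onehot.shape, 1 / 3)
print(round(weighted_cce(onehot, uniform, np.ones(3)), 4))  # 1.0986 (= ln 3)
```

In this notebook, such weights could be folded into combined_loss by multiplying the per-pixel cross-entropy by the weight of each pixel's true class, as weighted_cce does.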