Keras + Albumentations: Cats vs Dogs Classification
Complete training pipeline with automatic device detection (CUDA/MPS/CPU).
# %pip install -q tensorflow keras tensorflow-datasets
# %pip install -q albumentationsx
# %pip install -q opencv-python-headless matplotlib
import platform
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple
# Display plots inline in the notebook
%matplotlib inline
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers, models, optimizers, metrics, callbacks
import albumentations as A
print(f"TensorFlow: {tf.__version__}")
print(f"Keras: {keras.__version__}")
print(f"Albumentations: {A.__version__}")
TensorFlow: 2.19.1
Keras: 3.11.3
Albumentations: 2.0.11
def detect_device() -> Tuple[str, int]:
    """Detect the best available device and return (device_name, batch_size)."""
    cuda_gpus = tf.config.list_physical_devices("GPU")
    if cuda_gpus:
        print(f"✓ CUDA GPU detected: {len(cuda_gpus)} device(s)")
        for gpu in cuda_gpus:
            try:
                # Allocate GPU memory on demand instead of reserving it all up front
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError:
                # Memory growth must be set before the GPU is initialized
                pass
        return "CUDA_GPU", 64
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        print("✓ Apple Silicon MPS detected")
        return "MPS", 32
    print("✓ CPU detected")
    return "CPU", 16
DEVICE, BATCH_SIZE = detect_device()
print(f"✓ Using {DEVICE} with batch size {BATCH_SIZE}")
✓ Apple Silicon MPS detected
✓ Using MPS with batch size 32
class Config:
    input_shape = (128, 128, 3)
    num_classes = 1
    batch_size = BATCH_SIZE
    epochs = 10
    learning_rate = 1e-3
    train_split = "train[:70%]"
    val_split = "train[70%:90%]"
    test_split = "train[90%:]"
    model_path = f"cats_dogs_{DEVICE.lower()}.keras"

config = Config()
print(f"Batch size: {config.batch_size}, Epochs: {config.epochs}")
Batch size: 32, Epochs: 10
Data Augmentation with Albumentations
def create_augmentation_pipeline(is_training: bool = True) -> A.Compose:
    """Build the Albumentations pipeline for training or evaluation."""
    if is_training:
        return A.Compose([
            A.RandomResizedCrop(size=(128, 128), scale=(0.8, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
            A.OneOf([
                A.GaussNoise(std_range=(0.1, 0.2), p=1),
                A.GaussianBlur(blur_limit=(3, 5), p=1),
            ], p=0.2),
            A.CoarseDropout(
                num_holes_range=(1, 8),
                hole_height_range=(0.0625, 0.125),
                hole_width_range=(0.0625, 0.125),
                p=0.3,
            ),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    else:
        return A.Compose([
            A.Resize(height=128, width=128),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
train_transform = create_augmentation_pipeline(True)
val_transform = create_augmentation_pipeline(False)
print("✓ Augmentation pipelines created")
✓ Augmentation pipelines created
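A quick sanity check (added here as an illustrative sketch, not part of the original pipeline) confirms what the training transform produces: a synthetic uint8 image in, a normalized 128×128×3 float32 array out.

# Sanity check on a synthetic image (illustrative only)
dummy = np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
out = train_transform(image=dummy)["image"]
print(out.shape, out.dtype)                                 # (128, 128, 3) float32
print(f"value range: [{out.min():.2f}, {out.max():.2f}]")   # roughly [-2.1, 2.6] after ImageNet normalization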
Load and Prepare Dataset
# Load dataset
print("Downloading Cats vs Dogs dataset...")
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'cats_vs_dogs',
    split=[config.train_split, config.val_split, config.test_split],
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
)
print(f"✓ Dataset loaded: {ds_info.splits['train'].num_examples} examples")
# Apply the Albumentations transforms inside tf.data via tf.py_function
def augment_image(image, label, transform):
    def apply(img):
        augmented = transform(image=img.numpy())
        return augmented['image'].astype(np.float32)

    aug_img = tf.py_function(apply, [image], tf.float32)
    aug_img.set_shape([128, 128, 3])
    return aug_img, label
# Prepare datasets. Note: the training pipeline is not cached, so the random
# augmentations are re-sampled every epoch; caching after the augmentation map
# would freeze the first epoch's augmented images.
train_dataset = ds_train.map(
    lambda x, y: augment_image(x, y, train_transform),
    num_parallel_calls=tf.data.AUTOTUNE
).shuffle(1000).batch(config.batch_size).prefetch(tf.data.AUTOTUNE)

val_dataset = ds_val.map(
    lambda x, y: augment_image(x, y, val_transform),
    num_parallel_calls=tf.data.AUTOTUNE
).cache().batch(config.batch_size).prefetch(tf.data.AUTOTUNE)

test_dataset = ds_test.map(
    lambda x, y: augment_image(x, y, val_transform),
    num_parallel_calls=tf.data.AUTOTUNE
).cache().batch(config.batch_size).prefetch(tf.data.AUTOTUNE)

print("✓ Data pipeline ready")
Downloading Cats vs Dogs dataset...
✓ Dataset loaded: 23262 examples
✓ Data pipeline ready
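Before building the model, it can help to pull a single batch and check shapes and labels. The snippet below is a small illustrative addition that also undoes the ImageNet normalization so an augmented sample can be displayed.

# Inspect one training batch (illustrative)
images, labels = next(iter(train_dataset))
print("images:", images.shape, images.dtype)   # (batch, 128, 128, 3) float32
print("labels:", np.unique(labels.numpy()))    # 0 = cat, 1 = dog

# Undo normalization for display
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
sample = np.clip(images[0].numpy() * std + mean, 0, 1)
plt.imshow(sample)
plt.title(f"label = {int(labels[0])}")
plt.axis("off")
plt.show()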
Build and Train Model
# Create model
model = models.Sequential([
    layers.Input(shape=config.input_shape),

    # Block 1
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2),
    layers.Dropout(0.25),

    # Block 2
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2),
    layers.Dropout(0.25),

    # Block 3
    layers.Conv2D(128, 3, padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2),
    layers.Dropout(0.25),

    # Block 4
    layers.Conv2D(256, 3, padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.GlobalAveragePooling2D(),

    # Dense layers
    layers.Dense(256, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.5),
    layers.Dense(128, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.5),

    # Output
    layers.Dense(config.num_classes, activation='sigmoid')
])

# Compile
model.compile(
    optimizer=optimizers.Adam(learning_rate=config.learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy', metrics.AUC(name='auc')]
)

print(f"✓ Model created with {model.count_params():,} parameters")
model.summary()
✓ Model created with 490,689 parameters
Model: "sequential"

Layer (type)                              Output Shape               Param #
conv2d (Conv2D)                           (None, 128, 128, 32)           896
batch_normalization (BatchNormalization)  (None, 128, 128, 32)           128
max_pooling2d (MaxPooling2D)              (None, 64, 64, 32)               0
dropout (Dropout)                         (None, 64, 64, 32)               0
conv2d_1 (Conv2D)                         (None, 64, 64, 64)          18,496
batch_normalization_1                     (None, 64, 64, 64)             256
max_pooling2d_1 (MaxPooling2D)            (None, 32, 32, 64)               0
dropout_1 (Dropout)                       (None, 32, 32, 64)               0
conv2d_2 (Conv2D)                         (None, 32, 32, 128)         73,856
batch_normalization_2                     (None, 32, 32, 128)            512
max_pooling2d_2 (MaxPooling2D)            (None, 16, 16, 128)              0
dropout_2 (Dropout)                       (None, 16, 16, 128)              0
conv2d_3 (Conv2D)                         (None, 16, 16, 256)        295,168
batch_normalization_3                     (None, 16, 16, 256)          1,024
global_average_pooling2d                  (None, 256)                      0
dense (Dense)                             (None, 256)                 65,792
batch_normalization_4                     (None, 256)                  1,024
dropout_3 (Dropout)                       (None, 256)                      0
dense_1 (Dense)                           (None, 128)                 32,896
batch_normalization_5                     (None, 128)                    512
dropout_4 (Dropout)                       (None, 128)                      0
dense_2 (Dense)                           (None, 1)                      129

Total params: 490,689 (1.87 MB)
Trainable params: 488,961 (1.87 MB)
Non-trainable params: 1,728 (6.75 KB)
# Train with callbacks. Note that the checkpoint tracks val_auc, while early
# stopping and LR reduction track val_loss.
callbacks_list = [
    callbacks.ModelCheckpoint(
        'best_model.keras',
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1
    ),
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    )
]
print(f"Training on {DEVICE} for {config.epochs} epochs...")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=config.epochs,
    callbacks=callbacks_list,
    verbose=1
)
print("✓ Training completed!")
Training on MPS for 10 epochs... (per-epoch metrics condensed from the training log)

Epoch  loss    accuracy  auc     val_loss  val_accuracy  val_auc
1      0.7959  0.5438    0.5613  0.6740    0.5682        0.6341
2      0.6781  0.5885    0.6244  0.6142    0.6682        0.7307
3      0.6275  0.6461    0.7005  0.6419    0.6336        0.7236
4      0.5959  0.6768    0.7429  0.6048    0.6884        0.7741
5      0.5736  0.6987    0.7683  0.6826    0.6546        0.7838
6      0.5413  0.7255    0.8010  0.5432    0.7148        0.8450
7      0.5108  0.7457    0.8273  0.5091    0.7501        0.8514
8      0.4629  0.7790    0.8621  0.4355    0.8023        0.8958
9      0.4136  0.8101    0.8923  0.3472    0.8461        0.9285
10     0.3717  0.8350    0.9143  0.3945    0.8199        0.9255

Restoring model weights from the end of the best epoch: 9.
✓ Training completed!
Evaluate and Save Model
# Evaluate on test set
test_results = model.evaluate(test_dataset, verbose=1, return_dict=True)
print("\n✓ Test Results:")
for metric, value in test_results.items():
    print(f"  {metric}: {value:.4f}")

# Save model
model.save(config.model_path)
print(f"\n✓ Model saved as {config.model_path}")

# Convert to TFLite
try:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    tflite_path = config.model_path.replace('.keras', '.tflite')
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)
    print(f"✓ TFLite model saved as {tflite_path} ({len(tflite_model)/1024:.2f} KB)")
except Exception as e:
    print(f"TFLite conversion failed: {e}")

73/73 - 2s 33ms/step - accuracy: 0.8547 - auc: 0.9369 - loss: 0.3232

✓ Test Results:
  accuracy: 0.8547
  auc: 0.9369
  loss: 0.3232
✓ Model saved as cats_dogs_mps.keras
✓ TFLite model saved as cats_dogs_mps.tflite (501.87 KB)
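If the conversion above succeeded, the resulting TFLite file can be exercised with tf.lite.Interpreter. The snippet below is a minimal sketch using random input; a real image would be preprocessed with val_transform first.

# Minimal TFLite inference sketch (random input, illustrative only)
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

x = np.random.rand(1, 128, 128, 3).astype(np.float32)
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
prob = interpreter.get_tensor(out["index"])[0, 0]
print(f"P(class 1) = {prob:.3f}")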
Visualize Results
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Loss
axes[0].plot(history.history['loss'], label='Train')
axes[0].plot(history.history['val_loss'], label='Validation')
axes[0].set_title('Model Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Accuracy
axes[1].plot(history.history['accuracy'], label='Train')
axes[1].plot(history.history['val_accuracy'], label='Validation')
axes[1].set_title('Model Accuracy')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nFinal Results:")
print(f" Test Accuracy: {test_results['accuracy']:.3f}")
print(f" Test AUC: {test_results['auc']:.3f}")
print(f" Device Used: {DEVICE}")
print(f" Total Epochs: {len(history.history['loss'])}")
print("\n✓ Training complete! Model ready for deployment.")

Final Results:
  Test Accuracy: 0.855
  Test AUC: 0.937
  Device Used: MPS
  Total Epochs: 10

✓ Training complete! Model ready for deployment.
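As a deployment-style usage example (an illustrative sketch; the image path is a placeholder), the saved model can be reloaded and applied to a single image using the same val_transform preprocessing:

import cv2

loaded_model = keras.models.load_model(config.model_path)

def predict_image(path: str, threshold: float = 0.5) -> str:
    """Classify one image file: returns 'cat' or 'dog' with the model's confidence."""
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    x = val_transform(image=img)["image"][None, ...]   # resize + normalize + batch dim
    prob = float(loaded_model.predict(x, verbose=0)[0, 0])
    return f"dog ({prob:.2f})" if prob >= threshold else f"cat ({1 - prob:.2f})"

# Example with a hypothetical local file:
# print(predict_image("my_image.jpg"))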