PyTorch and Albumentations for image classification¶
This example shows how to use Albumentations for image classification. We will use the Cats vs. Dogs
dataset. The task will be to detect whether an image contains a cat or a dog.
Import the required libraries¶
In [1]:
from collections import defaultdict
import copy
import random
import os
import shutil
from urllib.request import urlretrieve
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
cudnn.benchmark = True
Define functions to download an archived dataset and unpack it¶
In [2]:
class TqdmUpTo(tqdm):
    """tqdm progress bar whose ``update_to`` matches the ``urlretrieve`` reporthook signature."""

    def update_to(self, b=1, bsize=1, tsize=None):
        # The reporthook supplies (block count, block size, total size).
        if tsize is not None:
            self.total = tsize
        downloaded = b * bsize
        # tqdm.update() expects an increment, so subtract what was already counted.
        self.update(downloaded - self.n)
def download_url(url, filepath):
    """Download ``url`` to ``filepath`` with a progress bar.

    Parent directories are created as needed; if the target file already
    exists the download is skipped.
    """
    target = os.path.abspath(filepath)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    if os.path.exists(target):
        print("Filepath already exists. Skipping download.")
        return
    progress = TqdmUpTo(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(target)
    )
    with progress as t:
        urlretrieve(url, filename=target, reporthook=t.update_to, data=None)
        # Snap the bar to the final byte count once the transfer finishes.
        t.total = t.n
def extract_archive(filepath):
    """Unpack the archive at ``filepath`` into the directory that contains it."""
    archive_path = os.path.abspath(filepath)
    shutil.unpack_archive(archive_path, os.path.dirname(archive_path))
Set the root directory for the downloaded dataset¶
In [3]:
dataset_directory = os.path.join(os.environ["HOME"], "datasets/cats-vs-dogs")
Download and extract the Cats vs. Dogs
dataset¶
In [4]:
# Download the Kaggle Cats-vs-Dogs archive from Microsoft's mirror (the call is a
# no-op if the zip already exists locally), then unpack it next to the zip file.
filepath = os.path.join(dataset_directory, "kagglecatsanddogs_3367a.zip")
download_url(
    url="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",
    filepath=filepath,
)
extract_archive(filepath)
Filepath already exists. Skipping download.
Split files from the dataset into the train and validation sets¶
Some files in the dataset are broken, so we will use only those image files that OpenCV could load correctly. We will use 20000 images for training, 4936 images for validation, and 10 images for testing.
In [5]:
# Gather every image path, drop files OpenCV cannot decode, then split
# deterministically into train / validation / test subsets.
root_directory = os.path.join(dataset_directory, "PetImages")
cat_directory = os.path.join(root_directory, "Cat")
dog_directory = os.path.join(root_directory, "Dog")
cat_images_filepaths = sorted(os.path.join(cat_directory, fname) for fname in os.listdir(cat_directory))
dog_images_filepaths = sorted(os.path.join(dog_directory, fname) for fname in os.listdir(dog_directory))
images_filepaths = cat_images_filepaths + dog_images_filepaths
# Some files in the archive are corrupt; keep only those cv2 can actually read.
correct_images_filepaths = []
for candidate in images_filepaths:
    if cv2.imread(candidate) is not None:
        correct_images_filepaths.append(candidate)
random.seed(42)  # fixed seed -> reproducible shuffle and split
random.shuffle(correct_images_filepaths)
train_images_filepaths = correct_images_filepaths[:20000]
val_images_filepaths = correct_images_filepaths[20000:-10]
test_images_filepaths = correct_images_filepaths[-10:]
print(len(train_images_filepaths), len(val_images_filepaths), len(test_images_filepaths))
20000 4936 10
Define a function to visualize images and their labels¶
Let's define a function that will take a list of images' file paths and their labels and visualize them in a grid. Correct labels are colored green, and incorrectly predicted labels are colored red.
In [6]:
def display_image_grid(images_filepaths, predicted_labels=(), cols=5):
    """Display images in a grid with their labels as titles.

    The true label is taken from each image's parent directory name
    (e.g. ".../Cat/1.jpg" -> "Cat"). Titles are green when the predicted
    label matches the true label and red otherwise.

    Args:
        images_filepaths: paths of the images to display.
        predicted_labels: optional sequence parallel to ``images_filepaths``;
            when empty, true labels are shown (and are always green).
        cols: number of grid columns.
    """
    # Round up so every image gets a cell: the previous floor division made the
    # grid too small (IndexError) whenever len(images_filepaths) % cols != 0.
    rows = -(-len(images_filepaths) // cols)
    # squeeze=False guarantees a 2-D Axes array even for a single row/column,
    # so ravel() below always works.
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6), squeeze=False)
    axes = ax.ravel()
    for i, image_filepath in enumerate(images_filepaths):
        image = cv2.imread(image_filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib wants RGB
        true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
        predicted_label = predicted_labels[i] if predicted_labels else true_label
        color = "green" if true_label == predicted_label else "red"
        axes[i].imshow(image)
        axes[i].set_title(predicted_label, color=color)
        axes[i].set_axis_off()
    # Blank out any trailing unused cells in the last row.
    for j in range(len(images_filepaths), rows * cols):
        axes[j].set_axis_off()
    plt.tight_layout()
    plt.show()
In [7]:
display_image_grid(test_images_filepaths)