Using Albumentations with Tensorflow¶
Author: Ayushman Buragohain
In [2]:
!pip install -q -U albumentations
!echo "$(pip freeze | grep albumentations) is successfully installed"
albumentations==0.4.6 is successfully installed
[Recommended] Update the version of tensorflow_datasets if you want to use it¶
- We'll be using an example from
tensorflow_datasets
.
In [ ]:
! pip install --upgrade tensorflow_datasets
Run the example¶
In [4]:
# necessary imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from functools import partial
from albumentations import (
Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
Rotate
)
# Shorthand for tf.data's AUTOTUNE constant; it is passed to
# `map(num_parallel_calls=...)` and `prefetch(...)` when the pipeline is built.
AUTOTUNE = tf.data.experimental.AUTOTUNE
In [5]:
# Show which tensorflow_datasets version this notebook is running against.
tfds.__version__
Out[5]:
'3.2.1'
In [6]:
# load in the tf_flowers dataset
# Load the "train" split of tf_flowers as (image, label) pairs, together with
# the dataset metadata.
data, info = tfds.load(
    name="tf_flowers",
    split="train",
    as_supervised=True,
    with_info=True,
)
data
Out[6]:
<PrefetchDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
In [7]:
# Display the metadata object returned by tfds.load(..., with_info=True).
info
Out[7]:
tfds.core.DatasetInfo( name='tf_flowers', version=3.0.1, description='A large set of images of flowers', homepage='https://www.tensorflow.org/tutorials/load_data/images', features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), total_num_examples=3670, splits={ 'train': 3670, }, supervised_keys=('image', 'label'), citation="""@ONLINE {tfflowers, author = "The TensorFlow Team", title = "Flowers", month = "jan", year = "2019", url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }""", redistribution_info=, )
An Example Pipeline Using tf.image
¶
Process Data¶
In [8]:
def process_image(image, label, img_size):
    """Prepare a single (image, label) example using plain ``tf.image`` ops.

    Args:
        image: Image tensor for one example.
        label: Label for the example; returned untouched.
        img_size: Side length of the square output image.

    Returns:
        A ``(image, label)`` tuple where the image has been converted to
        ``tf.float32``, randomly flipped left/right and resized to
        ``img_size`` x ``img_size``.
    """
    # Cast and normalize the image.
    as_float = tf.image.convert_image_dtype(image, tf.float32)
    # Simple augmentation: random horizontal flip, then resize to a square.
    flipped = tf.image.random_flip_left_right(as_float)
    target_size = [img_size, img_size]
    resized = tf.image.resize(flipped, target_size)
    return resized, label
# Build the tf.image pipeline: preprocess every example at 120x120, group the
# results into batches of 30 and prefetch.
ds_tf = (
    data
    .map(partial(process_image, img_size=120), num_parallel_calls=AUTOTUNE)
    .batch(30)
    .prefetch(AUTOTUNE)
)
ds_tf
Out[8]:
<PrefetchDataset shapes: ((None, 120, 120, 3), (None,)), types: (tf.float32, tf.int64)>
View images from the dataset¶
In [9]:
def view_image(ds):
    """Plot up to 20 images from the first batch of ``ds`` on a 4x5 grid.

    Args:
        ds: An iterable whose first element is an ``(image, label)`` pair of
            batched objects exposing ``.numpy()``.

    Each subplot is titled ``"Label: <label>"``. If the first batch holds
    fewer than 20 examples, only that many subplots are drawn instead of
    failing with an ``IndexError``.
    """
    image, label = next(iter(ds))  # extract 1 batch from the dataset
    image = image.numpy()
    label = label.numpy()

    # The grid has 20 cells; never index past the end of a smaller batch.
    n_shown = min(20, len(image), len(label))

    fig = plt.figure(figsize=(22, 22))
    for i in range(n_shown):
        ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
        ax.imshow(image[i])
        ax.set_title(f"Label: {label[i]}")
In [10]:
# Visualize the first batch produced by the tf.image pipeline.
view_image(ds_tf)