How to use a custom classification or semantic segmentation model¶
By default AutoAlbument uses
pytorch-image-models for classification and
segmentation_models.pytorch for semantic segmentation. You can use any model from these packages by providing an appropriate model name.
However, you can also use a custom model with AutoAlbument. To do so, you need to define a Discriminator model. This Discriminator model should have two outputs.
The first output should provide a prediction for a classification or semantic segmentation task. For classification, it should output a tensor with a shape
[batch_size, num_classes]with logits. For semantic segmentation, it should output a tensor with the shape
[batch_size, num_classes, height, width]with logits.
The second (auxiliary) output should return a tensor with the shape
[batch_size]that contains logits for Discriminator's predictions (whether Discriminator thinks that an image wasn't or was augmented).
To create such a model, you need to subclass the
autoalbument.faster_autoaugment.models.BaseDiscriminator class and implement the
forward method. This method should take a batch of images, that is, a tensor with the shape
[batch_size, num_channels, height, width]. It should return a tuple that contains tensors from the two outputs described above.
As an example, take a look at how default classification and semantic segmentation models are defined in AutoAlbument - https://github.com/albumentations-team/autoalbument/blob/master/autoalbument/faster_autoaugment/models.py or explore an example of a custom model for the CIFAR10 dataset.
Next, you need to specify this custom model in
config.yaml, an AutoAlbument config file.
AutoAlbument uses the
instantiate function from Hydra to instantiate an object. You need to set the
_target_ config variable in the
semantic_segmentation_model section, depending on the task. In this config variable, you need to provide a path to a class with the model. This path should be located inside PYTHONPATH, so Hydra could correctly use it. The simplest way is to define your model in a file such as
model.py and place this file in the same directory with
search.yaml because this directory is automatically added to PYTHONPATH. Next, you could define
_target_ such as