
How to use a custom classification or semantic segmentation model

By default, AutoAlbument uses pytorch-image-models for classification and segmentation_models.pytorch for semantic segmentation. You can use any model from these packages by providing the appropriate model name.

However, you can also use a custom model with AutoAlbument. To do so, you need to define a Discriminator model. This Discriminator model should have two outputs.

  • The first output should provide a prediction for the classification or semantic segmentation task. For classification, it should output a tensor of logits with the shape [batch_size, num_classes]. For semantic segmentation, it should output a tensor of logits with the shape [batch_size, num_classes, height, width].

  • The second (auxiliary) output should return a tensor with the shape [batch_size] that contains the logits for the Discriminator's predictions (that is, whether the Discriminator thinks an image was augmented or not).

To create such a model, you need to subclass the autoalbument.faster_autoaugment.models.BaseDiscriminator class and implement the forward method. This method should take a batch of images, that is, a tensor with the shape [batch_size, num_channels, height, width]. It should return a tuple that contains tensors from the two outputs described above.
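The two-output contract described above can be sketched as follows. This is a minimal illustration rather than the library's own model: the class name MyClassificationModel, the tiny backbone, and the constructor arguments are hypothetical, and the fallback base class is only a stand-in (assuming the real BaseDiscriminator behaves like a torch.nn.Module) so that the sketch also runs without AutoAlbument installed.

```python
import torch
from torch import nn

try:
    # The base class named in the text above.
    from autoalbument.faster_autoaugment.models import BaseDiscriminator
except ImportError:
    # Stand-in for illustration only (assumption: the real base class
    # is a regular torch.nn.Module subclass).
    class BaseDiscriminator(nn.Module):
        pass


class MyClassificationModel(BaseDiscriminator):
    """Hypothetical custom model with a task head and a discriminator head."""

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        # Tiny shared feature extractor: [B, C, H, W] -> [B, 16].
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(16, num_classes)  # first output
        self.discriminator = nn.Linear(16, 1)         # second (auxiliary) output

    def forward(self, images):
        features = self.backbone(images)
        class_logits = self.classifier(features)                 # [batch_size, num_classes]
        disc_logits = self.discriminator(features).squeeze(1)    # [batch_size]
        return class_logits, disc_logits


model = MyClassificationModel(num_classes=10)
batch = torch.randn(4, 3, 32, 32)  # [batch_size, num_channels, height, width]
class_logits, disc_logits = model(batch)
print(class_logits.shape, disc_logits.shape)
```

For semantic segmentation, the same pattern would apply, except that the first output keeps its spatial dimensions and has the shape [batch_size, num_classes, height, width].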

As an example, take a look at how the default classification and semantic segmentation models are defined in AutoAlbument, or explore an example of a custom model for the CIFAR10 dataset.

Next, you need to specify this custom model in config.yaml, an AutoAlbument config file. AutoAlbument uses the instantiate function from Hydra to instantiate an object. Set the _target_ config variable in the classification_model or semantic_segmentation_model section, depending on the task, to the import path of the class that defines the model. The module that contains this class must be on PYTHONPATH so that Hydra can import it. The simplest way is to define your model in a Python file placed in the same directory as search.yaml, because this directory is automatically added to PYTHONPATH. For example, if the class MyClassificationModel lives in a module named model, set _target_: model.MyClassificationModel.
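To make the role of _target_ more concrete, here is a simplified sketch of the idea behind resolving a dotted path to a class and instantiating it. It is not Hydra's actual implementation: the helper instantiate_from_config is hypothetical, and a standard-library class stands in for a custom model so that the example runs anywhere. It does show why the module has to be importable, which is the reason the PYTHONPATH requirement exists.

```python
import importlib


def instantiate_from_config(config):
    """Simplified illustration of `_target_` resolution (not Hydra's code).

    Resolves the dotted path in `_target_` to a class and passes the
    remaining keys to its constructor as keyword arguments.
    """
    params = dict(config)
    module_path, _, class_name = params.pop("_target_").rpartition(".")
    # This import fails if the module is not on PYTHONPATH.
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    return cls(**params)


# Stand-in for a parsed config section. With a custom model, `_target_`
# would instead look like "model.MyClassificationModel".
config = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
obj = instantiate_from_config(config)
print(type(obj).__name__, obj)
```

In a real config, any other keys in the model section would play the role of the constructor arguments here, so they should match the parameters that the custom class's __init__ accepts.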

Take a look at the CIFAR10 example config, which uses a custom model, as a starting point for defining your own.