"""GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training.
https://arxiv.org/abs/1805.06725
"""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import Callback, EarlyStopping
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from torch import Tensor, optim
from anomalib.models.components import AnomalyModule
from anomalib.models.ganomaly.loss import DiscriminatorLoss, GeneratorLoss
from .torch_model import GanomalyModel
logger = logging.getLogger(__name__)
class Ganomaly(AnomalyModule):
"""PL Lightning Module for the GANomaly Algorithm.
Args:
batch_size (int): Batch size.
input_size (tuple[int, int]): Input dimension.
        n_features (int): Number of feature layers in the CNNs.
latent_vec_size (int): Size of autoencoder latent vector.
extra_layers (int, optional): Number of extra layers for encoder/decoder. Defaults to 0.
add_final_conv_layer (bool, optional): Add convolution layer at the end. Defaults to True.
wadv (int, optional): Weight for adversarial loss. Defaults to 1.
        wcon (int, optional): Image reconstruction (contextual loss) weight. Defaults to 50.
        wenc (int, optional): Latent vector encoder weight. Defaults to 1.
        lr (float, optional): Learning rate. Defaults to 0.0002.
        beta1 (float, optional): Adam beta1. Defaults to 0.5.
        beta2 (float, optional): Adam beta2. Defaults to 0.999.
    """
def __init__(
self,
batch_size: int,
input_size: tuple[int, int],
n_features: int,
latent_vec_size: int,
extra_layers: int = 0,
add_final_conv_layer: bool = True,
wadv: int = 1,
wcon: int = 50,
wenc: int = 1,
lr: float = 0.0002,
beta1: float = 0.5,
beta2: float = 0.999,
) -> None:
super().__init__()
self.model: GanomalyModel = GanomalyModel(
input_size=input_size,
num_input_channels=3,
n_features=n_features,
latent_vec_size=latent_vec_size,
extra_layers=extra_layers,
add_final_conv_layer=add_final_conv_layer,
)
self.real_label = torch.ones(size=(batch_size,), dtype=torch.float32)
self.fake_label = torch.zeros(size=(batch_size,), dtype=torch.float32)
self.min_scores: Tensor = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores: Tensor = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable
self.generator_loss = GeneratorLoss(wadv, wcon, wenc)
self.discriminator_loss = DiscriminatorLoss()
# TODO: LR should be part of optimizer in config.yaml! Since ganomaly has custom
# optimizer this is to be addressed later.
self.learning_rate = lr
self.beta1 = beta1
self.beta2 = beta2
def _reset_min_max(self) -> None:
"""Resets min_max scores."""
self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable
    def training_step(
self, batch: dict[str, str | Tensor], batch_idx: int, optimizer_idx: int
) -> STEP_OUTPUT: # pylint: disable=arguments-differ
"""Training step.
Args:
batch (dict[str, str | Tensor]): Input batch containing images.
batch_idx (int): Batch index.
optimizer_idx (int): Optimizer which is being called for current training step.
Returns:
STEP_OUTPUT: Loss
"""
        del batch_idx  # `batch_idx` is not used.
# forward pass
padded, fake, latent_i, latent_o = self.model(batch["image"])
pred_real, _ = self.model.discriminator(padded)
if optimizer_idx == 0: # Discriminator
pred_fake, _ = self.model.discriminator(fake.detach())
loss = self.discriminator_loss(pred_real, pred_fake)
else: # Generator
pred_fake, _ = self.model.discriminator(fake)
loss = self.generator_loss(latent_i, latent_o, padded, fake, pred_real, pred_fake)
self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True)
return {"loss": loss}
    def on_validation_start(self) -> None:
"""Reset min and max values for current validation epoch."""
self._reset_min_max()
return super().on_validation_start()
    def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT:
"""Update min and max scores from the current step.
Args:
            batch (dict[str, str | Tensor]): Input batch containing images.
Returns:
(STEP_OUTPUT): Output predictions.
"""
batch["pred_scores"] = self.model(batch["image"])
self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"]))
self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"]))
return batch
    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> EPOCH_OUTPUT:
"""Normalize outputs based on min/max values."""
logger.info("Normalizing validation outputs based on min/max values.")
for prediction in outputs:
prediction["pred_scores"] = self._normalize(prediction["pred_scores"])
super().validation_epoch_end(outputs)
return outputs
    def on_test_start(self) -> None:
"""Reset min max values before test batch starts."""
self._reset_min_max()
return super().on_test_start()
    def test_step(self, batch: dict[str, str | Tensor], batch_idx: int, *args, **kwargs) -> STEP_OUTPUT:
"""Update min and max scores from the current step."""
super().test_step(batch, batch_idx)
self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"]))
self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"]))
return batch
    def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> EPOCH_OUTPUT:
"""Normalize outputs based on min/max values."""
logger.info("Normalizing test outputs based on min/max values.")
for prediction in outputs:
prediction["pred_scores"] = self._normalize(prediction["pred_scores"])
super().test_epoch_end(outputs)
return outputs
def _normalize(self, scores: Tensor) -> Tensor:
"""Normalize the scores based on min/max of entire dataset.
Args:
scores (Tensor): Un-normalized scores.
Returns:
Tensor: Normalized scores.
"""
scores = (scores - self.min_scores.to(scores.device)) / (
self.max_scores.to(scores.device) - self.min_scores.to(scores.device)
)
return scores
class GanomalyLightning(Ganomaly):
"""PL Lightning Module for the GANomaly Algorithm.
Args:
hparams (DictConfig | ListConfig): Model params
"""
def __init__(self, hparams: DictConfig | ListConfig) -> None:
super().__init__(
batch_size=hparams.dataset.train_batch_size,
input_size=hparams.model.input_size,
n_features=hparams.model.n_features,
latent_vec_size=hparams.model.latent_vec_size,
extra_layers=hparams.model.extra_layers,
add_final_conv_layer=hparams.model.add_final_conv,
wadv=hparams.model.wadv,
wcon=hparams.model.wcon,
wenc=hparams.model.wenc,
lr=hparams.model.lr,
beta1=hparams.model.beta1,
beta2=hparams.model.beta2,
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
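
# Usage sketch (illustrative, not part of the library module): build the Lightning
# model from a dict-style config exposing the keys read in ``GanomalyLightning.__init__``.
# All concrete values below are assumptions for demonstration, not defaults taken from
# a released anomalib config file.
if __name__ == "__main__":
    _config = DictConfig(
        {
            "dataset": {"train_batch_size": 32},
            "model": {
                "input_size": [256, 256],
                "n_features": 64,
                "latent_vec_size": 100,
                "extra_layers": 0,
                "add_final_conv": True,
                "wadv": 1,
                "wcon": 50,
                "wenc": 1,
                "lr": 0.0002,
                "beta1": 0.5,
                "beta2": 0.999,
            },
        }
    )
    _model = GanomalyLightning(_config)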