Source code for otx.algo.anomaly.padim

"""OTX Padim model."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from anomalib.models.image import Padim as AnomalibPadim

from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
from otx.core.types.label import AnomalyLabelInfo
from otx.core.types.task import OTXTaskType

if TYPE_CHECKING:
    from otx.core.types.label import LabelInfoTypes


class Padim(AnomalyMixin, AnomalibPadim, OTXAnomaly):
    """OTX Padim model.

    Args:
        label_info (LabelInfoTypes, optional): Label information. It is not
            referenced by this constructor. NOTE(review): presumably accepted
            only so the signature matches other OTX models; confirm.
            Defaults to ``AnomalyLabelInfo()``.
        backbone (str, optional): Feature extractor backbone. Defaults to "resnet18".
        layers (list[str], optional): Feature extractor layers. A copy of the list
            is forwarded to the parent initializer, so neither the caller's list
            nor the shared default list is handed to downstream code.
            Defaults to ["layer1", "layer2", "layer3"].
        pre_trained (bool, optional): Pretrained backbone. Defaults to True.
        n_features (int | None, optional): Number of features. Defaults to None.
        task (Literal[
                OTXTaskType.ANOMALY,
                OTXTaskType.ANOMALY_CLASSIFICATION,
                OTXTaskType.ANOMALY_DETECTION,
                OTXTaskType.ANOMALY_SEGMENTATION,
            ], optional): Task type of Anomaly Task. The value is normalized with
            ``OTXTaskType(task)``. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
        input_size (tuple[int, int], optional): Model input size in the order of
            height and width. Defaults to (256, 256).
    """

    def __init__(
        self,
        label_info: LabelInfoTypes = AnomalyLabelInfo(),
        backbone: str = "resnet18",
        layers: list[str] = ["layer1", "layer2", "layer3"],  # noqa: B006
        pre_trained: bool = True,
        n_features: int | None = None,
        task: Literal[
            OTXTaskType.ANOMALY,
            OTXTaskType.ANOMALY_CLASSIFICATION,
            OTXTaskType.ANOMALY_DETECTION,
            OTXTaskType.ANOMALY_SEGMENTATION,
        ] = OTXTaskType.ANOMALY_CLASSIFICATION,
        input_size: tuple[int, int] = (256, 256),
    ) -> None:
        # These attributes are assigned before ``super().__init__()``.
        # NOTE(review): presumably a parent initializer reads them, so keep
        # this order unless that is confirmed otherwise.
        self.input_size = input_size
        self.task = OTXTaskType(task)
        super().__init__(
            backbone=backbone,
            # Forward a copy: the default list is created once at definition
            # time and shared by every call. Any downstream code that keeps or
            # edits this list (anomalib's feature extractor may drop layers it
            # cannot find; verify) would otherwise corrupt that default or the
            # caller's list.
            layers=list(layers),
            pre_trained=pre_trained,
            n_features=n_features,
        )