Source code for otx.algo.anomaly.stfpm
"""OTX STFPM model."""
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Sequence
from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm
from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
from otx.core.types.label import AnomalyLabelInfo
from otx.core.types.task import OTXTaskType
if TYPE_CHECKING:
from otx.core.types.label import LabelInfoTypes
[docs]
class Stfpm(AnomalyMixin, AnomalibStfpm, OTXAnomaly):
    """OTX STFPM model.

    Args:
        label_info (LabelInfoTypes, optional): Label information.
            Defaults to ``AnomalyLabelInfo()``. Not referenced by this constructor.
        layers (Sequence[str], optional): Feature extractor layers.
            Defaults to ("layer1", "layer2", "layer3").
        backbone (str, optional): Feature extractor backbone. Defaults to "resnet18".
        task (Literal[
            OTXTaskType.ANOMALY, OTXTaskType.ANOMALY_CLASSIFICATION,
            OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION
            ], optional): Task type of Anomaly Task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
        input_size (tuple[int, int], optional):
            Model input size in the order of height and width. Defaults to (256, 256)
        **kwargs: Extra keyword arguments. Not referenced by this constructor.
    """

    def __init__(
        self,
        label_info: LabelInfoTypes = AnomalyLabelInfo(),
        # Immutable default: a list literal here would be created once and
        # shared by every instance constructed without an explicit ``layers``.
        layers: Sequence[str] = ("layer1", "layer2", "layer3"),
        backbone: str = "resnet18",
        task: Literal[
            OTXTaskType.ANOMALY,
            OTXTaskType.ANOMALY_CLASSIFICATION,
            OTXTaskType.ANOMALY_DETECTION,
            OTXTaskType.ANOMALY_SEGMENTATION,
        ] = OTXTaskType.ANOMALY_CLASSIFICATION,
        input_size: tuple[int, int] = (256, 256),
        **kwargs,
    ) -> None:
        # Assigned before super().__init__(); presumably the parent/mixin
        # initializers rely on these attributes already existing — TODO confirm.
        self.input_size = input_size
        # Coerce to the OTXTaskType enum (accepts either a member or its value).
        self.task = OTXTaskType(task)
        # Give the parent its own fresh list so nothing it does with ``layers``
        # can leak back into the caller's sequence or into the default.
        super().__init__(
            backbone=backbone,
            layers=list(layers),
        )