"""Anomaly Classification Task."""
# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import io
from typing import Optional
import torch
from anomalib.models import AnomalyModule, get_model
from anomalib.post_processing import NormalizationMethod, ThresholdMethod
from anomalib.utils.callbacks import (
MetricsConfigurationCallback,
MinMaxNormalizationCallback,
PostProcessingConfigurationCallback,
)
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers.csv_logs import CSVLogger
from otx.algorithms.anomaly.adapters.anomalib.callbacks import IterationTimer, ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin
from otx.algorithms.common.utils.utils import is_xpu_available
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import ModelEntity
from otx.api.entities.train_parameters import TrainParameters
from otx.api.usecases.tasks.interfaces.training_interface import ITrainingTask
from otx.utils.logger import get_logger
from .inference import InferenceTask
# Module-level logger used by TrainingTask for progress and diagnostic messages.
logger = get_logger()
[docs]
class TrainingTask(InferenceTask, ITrainingTask):
    """Base Anomaly Task.

    Extends ``InferenceTask`` with training support for anomalib models.
    """

    def train(
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: TrainParameters,
        seed: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Train the anomaly model and save the resulting weights.

        Args:
            dataset (DatasetEntity): Input dataset.
            output_model (ModelEntity): Output model to save the model weights.
            train_parameters (TrainParameters): Training parameters, forwarded
                to the progress callback.
            seed (Optional[int]): Random seed. Seeding is skipped when the value
                is ``None`` or ``0``.
            deterministic (bool): When True, the Lightning trainer's
                ``deterministic`` flag is set to ``"warn"``; otherwise the given
                value is passed through unchanged.
        """
        logger.info("Training the model.")

        config = self.get_config()

        # Falsy seeds (None and 0) intentionally leave the RNG state untouched.
        if seed:
            logger.info("Setting seed to %s", seed)
            seed_everything(seed, workers=True)
        # "warn" asks Lightning to prefer deterministic algorithms but only warn
        # (instead of failing) on ops that have no deterministic implementation.
        config.trainer.deterministic = "warn" if deterministic else deterministic

        logger.info("Training Configs '%s'", config)

        datamodule = OTXAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
        callbacks = [
            ProgressCallback(parameters=train_parameters),
            MinMaxNormalizationCallback(),
            MetricsConfigurationCallback(
                task=config.dataset.task,
                image_metrics=config.metrics.image,
                pixel_metrics=config.metrics.get("pixel"),
            ),
            PostProcessingConfigurationCallback(
                normalization_method=NormalizationMethod.MIN_MAX,
                threshold_method=ThresholdMethod.ADAPTIVE,
                manual_image_threshold=config.metrics.threshold.manual_image,
                manual_pixel_threshold=config.metrics.threshold.manual_pixel,
            ),
            IterationTimer(on_step=False),
        ]

        # Plugins are passed to the Trainer explicitly below, so they must be
        # removed from the trainer config to avoid a duplicate ``plugins``
        # keyword argument. The None default tolerates configs without the key.
        plugins = []
        config_plugins = config.trainer.pop("plugins", None)
        if config_plugins is not None:
            plugins.extend(config_plugins)

        if is_xpu_available():
            config.trainer.strategy = "xpu_single"
            config.trainer.accelerator = "xpu"
            # Lightning accepts half precision as either 16 or "16".
            if str(config.trainer.precision) == "16":
                plugins.append(MixedPrecisionXPUPlugin())

        self.trainer = Trainer(
            **config.trainer, logger=CSVLogger(self.project_path, name=""), callbacks=callbacks, plugins=plugins
        )
        self.trainer.fit(model=self.model, datamodule=datamodule)

        self.save_model(output_model)

        logger.info("Training completed.")

    def load_model(self, otx_model: Optional[ModelEntity]) -> AnomalyModule:
        """Create an Anomalib module and optionally load weights from an OTX model.

        A new model is always built from ``self.config``. If ``otx_model`` is
        given and its saved config uses the same backbone as ``self.config``,
        the saved weights are loaded into that model; otherwise the freshly
        created model is returned as-is.

        Args:
            otx_model (Optional[ModelEntity]): OTX Model from the
                task environment.

        Returns:
            AnomalyModule: Anomalib classification or segmentation model
                with/without weights.

        Raises:
            ValueError: If the saved weights do not have the expected structure
                or their state dict cannot be loaded into the model.
        """
        model = get_model(config=self.config)
        if otx_model is None:
            logger.info(
                "No trained model in project yet. Created new model with '%s'",
                self.model_name,
            )
            return model

        buffer = io.BytesIO(otx_model.get_data("weights.pth"))
        # NOTE(review): torch.load unpickles arbitrary objects; weights must come
        # from a trusted source.
        model_data = torch.load(buffer, map_location=torch.device("cpu"))

        try:
            if model_data["config"]["model"]["backbone"] == self.config["model"]["backbone"]:
                model.load_state_dict(model_data["model"])
                logger.info("Loaded model weights from Task Environment")
            else:
                logger.info(
                    "Model backbone does not match. Created new model with '%s'",
                    self.model_name,
                )
        # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit
        # propagate instead of being disguised as a malformed model file.
        except Exception as exception:
            raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

        return model