Source code for otx.core.model.diffusion
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Class definition for diffusion model entity used in OTX."""
from __future__ import annotations
from typing import TYPE_CHECKING
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.diffusion import DiffusionBatchDataEntity, DiffusionBatchPredEntity
from otx.core.metrics.diffusion import DiffusionMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
if TYPE_CHECKING:
import torch
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from torch import nn
from otx.core.metrics import MetricCallable
[docs]
class OTXDiffusionModel(OTXModel[DiffusionBatchDataEntity, DiffusionBatchPredEntity]):
    """Base OTX model entity for diffusion tasks.

    The metric instance is built once, inside the constructor; the
    validation/test start hooks are overridden as no-ops so it is not
    rebuilt there. During the first training epoch the real training images
    are fed to the metric. Concrete subclasses supply the network through
    ``_create_model``.
    """

    def __init__(
        self,
        optimizer: OptimizerCallable = DefaultOptimizerCallable,
        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
        metric: MetricCallable = DiffusionMetricCallable,
        label_info: int = 0,
        **kwargs,
    ):
        """Initialize the diffusion model entity.

        Args:
            optimizer: Callable that builds the optimizer.
            scheduler: Callable, or list of callables, that builds the LR scheduler(s).
            metric: Callable that builds the evaluation metric; defaults to
                ``DiffusionMetricCallable``.
            label_info: Forwarded unchanged to ``OTXModel``.
            **kwargs: Any extra keyword arguments, forwarded to ``OTXModel``.
        """
        super().__init__(
            optimizer=optimizer,
            scheduler=scheduler,
            metric=metric,
            label_info=label_info,
            **kwargs,
        )
        # Build the metric eagerly; on_validation_start/on_test_start are
        # intentionally empty so this same instance is reused across stages.
        self.configure_metric()

    def _create_model(self) -> nn.Module:
        """Build the underlying network; not provided by this base class."""
        raise NotImplementedError()

    def training_step(self, batch: DiffusionBatchDataEntity, batch_idx: int) -> torch.Tensor:
        """Run one training step; during epoch 0 also record real images in the metric.

        Args:
            batch: Training batch; its ``images`` are passed to the metric as real samples.
            batch_idx: Index of ``batch`` within the current epoch.

        Returns:
            The loss produced by ``OTXModel.training_step``.
        """
        loss = super().training_step(batch, batch_idx)
        is_first_epoch = self.current_epoch == 0
        if is_first_epoch:
            # Real images are only pushed during the first epoch. NOTE(review):
            # presumably so reference statistics are accumulated once — confirm
            # the metric retains them across resets / checkpoint resumes.
            self.metric.update(batch.images, real=True)
        return loss

    def _customize_outputs(
        self,
        outputs: torch.Tensor,
        inputs: DiffusionBatchDataEntity,
    ) -> DiffusionBatchPredEntity | OTXBatchLossEntity:
        """Wrap the raw network output in a loss entity.

        ``inputs`` is not used. The output is always treated as a loss, with no
        check of training/eval mode. NOTE(review): subclasses that emit
        predictions presumably override this.
        """
        loss_entity = OTXBatchLossEntity(loss=outputs)
        return loss_entity

    def on_validation_start(self) -> None:
        """Hook run when validation begins; intentionally does nothing.

        The metric is configured once in the constructor and must not be
        configured again here.
        """
        return

    def on_test_start(self) -> None:
        """Hook run when testing begins; intentionally does nothing.

        The metric is configured once in the constructor and must not be
        configured again here.
        """
        return

    def test_step(self, batch: DiffusionBatchDataEntity, batch_idx: int) -> None:
        """Run one test step by reusing the validation logic.

        Args:
            batch: Test batch entity.
            batch_idx: Index of ``batch`` within the test epoch.
        """
        # Whatever validation_step returns is discarded; this hook returns None.
        self.validation_step(batch, batch_idx)

    @property
    def _export_parameters(self) -> TaskLevelExportParameters:
        """Export parameters of the base model, tagged as a UNet diffusion model."""
        base_parameters = super()._export_parameters
        return base_parameters.wrap(task_type="diffusion", model_type="unet")