otx.core.model.diffusion

Class definition for diffusion model entity used in OTX.

Classes

OTXDiffusionModel(optimizer, scheduler, ...)

OTX Diffusion model.

class otx.core.model.diffusion.OTXDiffusionModel(optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _diffusion_metric_callable>, label_info: int = 0, **kwargs)

Bases: OTXModel[DiffusionBatchDataEntity, DiffusionBatchPredEntity]

OTX Diffusion model.
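The optimizer and scheduler arguments are callables, not ready-made objects. The defaults shown above are functions, and in Lightning an `OptimizerCallable` takes the model's parameters and returns an optimizer. The sketch below illustrates that deferred-construction idea in pure Python. `FakeOptimizer`, `FakeScheduler`, and `SketchModel` are hypothetical stand-ins, not OTX or torch classes, and this is not OTX's actual `configure_optimizers`.

```python
from typing import Callable, Iterable, List


class FakeOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer; it only records its inputs."""

    def __init__(self, params: Iterable[float], lr: float) -> None:
        self.params = list(params)
        self.lr = lr


class FakeScheduler:
    """Hypothetical stand-in for an LR scheduler wrapping an optimizer."""

    def __init__(self, optimizer: FakeOptimizer, step_size: int) -> None:
        self.optimizer = optimizer
        self.step_size = step_size


# Factory types mirroring the idea behind OptimizerCallable / LRSchedulerCallable.
OptimizerFactory = Callable[[Iterable[float]], FakeOptimizer]
SchedulerFactory = Callable[[FakeOptimizer], FakeScheduler]


class SketchModel:
    """Hypothetical model that stores factories and calls them later."""

    def __init__(self, optimizer: OptimizerFactory, scheduler: SchedulerFactory) -> None:
        self.optimizer_callable = optimizer
        self.scheduler_callable = scheduler
        self.weights: List[float] = [0.1, 0.2, 0.3]  # stand-in for model parameters

    def configure_optimizers(self):
        # The factories run only now, once the model's parameters exist.
        opt = self.optimizer_callable(self.weights)
        sched = self.scheduler_callable(opt)
        return [opt], [sched]


model = SketchModel(
    optimizer=lambda params: FakeOptimizer(params, lr=1e-4),
    scheduler=lambda opt: FakeScheduler(opt, step_size=10),
)
optimizers, schedulers = model.configure_optimizers()
```

Passing factories lets a config file choose the optimizer type and hyperparameters without the caller needing access to the model's parameters at construction time.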

configure_metric() → None

Configure the metric.

on_test_start() → None

Called at the beginning of testing.

Don’t configure the metric here; do it in the constructor.

on_validation_start() → None

Called at the beginning of validation.

Don’t configure the metric here; do it in the constructor.
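The hooks above leave metric creation to the constructor. The sketch below shows that pattern: the metric object is built once from a factory, and the start hooks never replace it. `MeanMetric` and `MetricSketchModel` are hypothetical, the zero-argument factory is an assumption (OTX's `MetricCallable` may take arguments), and having the hooks clear accumulated state is an illustrative choice, not necessarily what OTX does.

```python
from typing import Callable, List


class MeanMetric:
    """Hypothetical stand-in for a torchmetrics-style metric (update/compute/reset)."""

    def __init__(self) -> None:
        self.values: List[float] = []

    def update(self, value: float) -> None:
        self.values.append(value)

    def compute(self) -> float:
        return sum(self.values) / len(self.values)

    def reset(self) -> None:
        self.values.clear()


class MetricSketchModel:
    def __init__(self, metric: Callable[[], MeanMetric]) -> None:
        # Build the metric exactly once, at construction time.
        self.metric = metric()

    def on_validation_start(self) -> None:
        # Do not rebuild the metric; only clear state from a previous run (assumption).
        self.metric.reset()

    def on_test_start(self) -> None:
        self.metric.reset()


model = MetricSketchModel(metric=MeanMetric)
first = model.metric
model.metric.update(2.0)
model.metric.update(4.0)
result = model.metric.compute()  # 3.0
model.on_validation_start()
model.on_test_start()
```

Keeping the same metric instance across stages means whatever device placement or state registration happens at construction is not lost when a new stage begins.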

test_step(batch: DiffusionBatchDataEntity, batch_idx: int) → None

Perform a single test step on a batch of data from the test set.

Parameters:
  • batch – A batch of test data, given as a DiffusionBatchDataEntity.

  • batch_idx – The index of the current batch.
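The trainer calls test_step once per batch, passing the batch together with its zero-based index. Because test_step returns None, any results must be recorded as side effects (presumably metric updates). The loop below is a simplified, hypothetical driver, not Lightning's actual test loop, and it uses plain lists in place of DiffusionBatchDataEntity.

```python
from typing import Iterable, List, Sequence, Tuple


class SketchTestModel:
    """Hypothetical model whose test_step records what it saw."""

    def __init__(self) -> None:
        self.seen: List[Tuple[int, int]] = []

    def test_step(self, batch: Sequence[float], batch_idx: int) -> None:
        # Returns None; the outcome is stored on the model instead.
        self.seen.append((batch_idx, len(batch)))


def run_test_loop(model: SketchTestModel, batches: Iterable[Sequence[float]]) -> None:
    # Each batch is paired with its zero-based position in the dataloader.
    for batch_idx, batch in enumerate(batches):
        model.test_step(batch, batch_idx)


model = SketchTestModel()
run_test_loop(model, [[0.1, 0.2], [0.3, 0.4], [0.5]])
# model.seen == [(0, 2), (1, 2), (2, 1)]
```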

training_step(batch: DiffusionBatchDataEntity, batch_idx: int) → torch.Tensor

Perform a single training step on a batch of data from the training set and return the training loss.
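This page does not say which objective training_step optimizes. As an assumption, the sketch below shows the standard DDPM noise-prediction loss that many diffusion trainers use: sample a timestep, noise the clean input with the forward process, and regress the network's output onto the injected noise. It uses NumPy instead of torch, returns a float rather than a torch.Tensor, and `ddpm_training_loss` and the zero-noise "model" are hypothetical.

```python
import numpy as np


def linear_beta_schedule(num_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> np.ndarray:
    # Linearly spaced noise variances beta_1 .. beta_T.
    return np.linspace(beta_start, beta_end, num_steps)


def ddpm_training_loss(x0, predict_noise, alphas_cumprod, rng) -> float:
    batch = x0.shape[0]
    # Sample one timestep per example.
    t = rng.integers(0, len(alphas_cumprod), size=batch)
    noise = rng.standard_normal(x0.shape)
    # Broadcast a_bar_t over all non-batch dimensions.
    a_bar = alphas_cumprod[t].reshape(batch, *([1] * (x0.ndim - 1)))
    # Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise.
    x_t = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise
    # The network is trained to recover the injected noise (mean squared error).
    pred = predict_noise(x_t, t)
    return float(np.mean((pred - noise) ** 2))


betas = linear_beta_schedule(1000)
alphas_cumprod = np.cumprod(1.0 - betas)
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 3, 8, 8))  # a fake batch of 4 RGB 8x8 images
# Hypothetical "model" that always predicts zero noise; its loss is close to 1.
loss = ddpm_training_loss(x0, lambda x_t, t: np.zeros_like(x_t), alphas_cumprod, rng)
```

A predictor that outputs zeros scores roughly the variance of standard Gaussian noise, so a useful model should drive this loss well below 1.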