otx.core.model.keypoint_detection

Class definition for the keypoint detection model entity used in OTX.

Classes

OTXKeypointDetectionModel(label_info, ...)

Base class for the keypoint detection models used in OTX.

class otx.core.model.keypoint_detection.OTXKeypointDetectionModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _pck_measure_callable>, torch_compile: bool = False)

Bases: OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]

Base class for the keypoint detection models used in OTX.

configure_metric() → None

Configure the metric.
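
The default metric callable in the constructor signature is named `_pck_measure_callable`, which suggests a PCK (Percentage of Correct Keypoints) measure. This page does not document its threshold or normalization, so the following is only a minimal sketch of the general PCK idea under assumed conventions: the function name `pck`, the `threshold` and `norm` parameters, and the sample data are all hypothetical and are not taken from OTX.

```python
from math import hypot


def pck(pred, target, visible, threshold, norm):
    """Generic PCK sketch (not the OTX implementation).

    A keypoint counts as correct when its Euclidean distance to the
    ground truth, divided by `norm`, is at most `threshold`.
    Keypoints marked as not visible are ignored.
    """
    correct = 0
    total = 0
    for (px, py), (tx, ty), vis in zip(pred, target, visible):
        if not vis:
            continue  # skip keypoints without a valid annotation
        total += 1
        if hypot(px - tx, py - ty) / norm <= threshold:
            correct += 1
    return correct / total if total else 0.0


# Hypothetical data: four keypoints, the last one is not visible.
pred = [(10.0, 10.0), (20.0, 22.0), (40.0, 40.0), (0.0, 0.0)]
target = [(10.0, 11.0), (20.0, 20.0), (30.0, 40.0), (5.0, 5.0)]
visible = [True, True, True, False]

# Assumed normalization: a reference size of 100 pixels.
score = pck(pred, target, visible, threshold=0.05, norm=100.0)
print(f"PCK = {score:.3f}")
```

In this sketch two of the three visible keypoints fall within the normalized threshold, so the score is 2/3. The normalization factor (for example a bounding-box or input size) is a design choice that the metric actually configured by `configure_metric()` may handle differently.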

forward_for_tracing(image: Tensor) → Tensor | tuple[Tensor]

Forward function used to trace the model during export.

get_classification_layers(prefix: str = 'model.') → dict[str, dict[str, int]]

Get information about the final classification layers for the incremental learning case.
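
This page only gives the return type, `dict[str, dict[str, int]]`, so the shape of the result can be illustrated but not confirmed. The sketch below shows one plausible way such a mapping could be derived: compare parameter shapes between two variants of a model built with different class counts and record the layers whose first dimension changes. The helper name `find_classification_layers`, the inner keys `"stride"` and `"num_extra_classes"`, and the layer names are assumptions made for illustration, not the documented OTX behavior.

```python
def find_classification_layers(shapes_a, shapes_b, classes_a, classes_b, prefix="model."):
    """Hypothetical helper: find layers whose shape depends on the class count.

    `shapes_a` and `shapes_b` map parameter names to shapes for models
    built with `classes_a` and `classes_b` classes respectively.
    """
    layers = {}
    for name, shape_a in shapes_a.items():
        shape_b = shapes_b[name]
        if shape_a == shape_b:
            continue  # shape is independent of the class count
        # Rows added when the class count changes from classes_a to classes_b.
        stride = shape_b[0] - shape_a[0]
        # Rows that do not scale with the class count (e.g. a background row).
        num_extra = classes_b * shape_a[0] - classes_a * shape_b[0]
        layers[prefix + name] = {"stride": stride, "num_extra_classes": num_extra}
    return layers


# Hypothetical parameter shapes for a 5-class and a 6-class variant.
shapes_5 = {"backbone.conv.weight": (64, 3, 3, 3), "head.fc.weight": (5, 256), "head.fc_bg.weight": (6, 256)}
shapes_6 = {"backbone.conv.weight": (64, 3, 3, 3), "head.fc.weight": (6, 256), "head.fc_bg.weight": (7, 256)}

info = find_classification_layers(shapes_5, shapes_6, classes_a=5, classes_b=6)
for layer_name, layer_info in info.items():
    print(layer_name, layer_info)
```

Here the backbone layer is skipped because its shape is identical in both variants, while the two head layers are reported under keys that start with the `prefix` argument.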