otx.core.model.keypoint_detection

Class definition for keypoint detection model entity used in OTX.

Classes

OTXKeypointDetectionModel(label_info, ...)

Base class for the keypoint detection models used in OTX.

class otx.core.model.keypoint_detection.OTXKeypointDetectionModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _pck_measure_callable>, torch_compile: bool = False)

Bases: OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]

Base class for the keypoint detection models used in OTX.
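The optimizer and scheduler arguments take callables rather than ready-made objects. Assuming OptimizerCallable follows the Lightning CLI convention (a callable that receives the model parameters and returns an optimizer), the point is to defer construction until the parameters exist. The sketch below illustrates that pattern with standard-library stand-ins; Param, SGD, make_sgd and TinyModel are hypothetical and are not part of OTX or PyTorch.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, List


# Hypothetical stand-ins for torch.nn.Parameter and torch.optim.SGD,
# used only to illustrate the deferred-construction pattern.
@dataclass
class Param:
    value: float


@dataclass
class SGD:
    params: List[Param]
    lr: float


# Assumed shape of an optimizer callable: parameters in, optimizer out.
OptimizerFactory = Callable[[Iterable[Param]], SGD]


def make_sgd(lr: float) -> OptimizerFactory:
    """Return a factory that builds an optimizer once parameters exist."""

    def factory(params: Iterable[Param]) -> SGD:
        return SGD(params=list(params), lr=lr)

    return factory


class TinyModel:
    """Hypothetical model that stores the factory instead of an optimizer."""

    def __init__(self, optimizer: OptimizerFactory) -> None:
        # Parameters do not exist yet, so only the factory is kept.
        self.optimizer_factory = optimizer
        self.weights: List[Param] = []

    def build(self, num_weights: int) -> None:
        # Create the (dummy) parameters.
        self.weights = [Param(0.0) for _ in range(num_weights)]

    def configure_optimizers(self) -> SGD:
        # The parameters are known now, so the factory can be called.
        return self.optimizer_factory(self.weights)


model = TinyModel(optimizer=make_sgd(lr=0.01))
model.build(num_weights=3)
opt = model.configure_optimizers()
print(len(opt.params), opt.lr)  # 3 0.01
```

Passing a factory lets a configuration system describe the optimizer (type and hyperparameters) before the model, and therefore its parameters, has been instantiated.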

configure_metric() → None

Configure the metric.
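The default metric is built by _pck_measure_callable, whose name suggests PCK (Percentage of Correct Keypoints): the fraction of visible keypoints whose predicted location lies within a normalized distance of the ground truth. The following is a minimal standard-library sketch of that computation; the pck function, the per-instance normalization factor and the 0.05 threshold are assumptions for illustration, not the exact settings OTX uses.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def pck(
    pred: Sequence[Sequence[Point]],
    gt: Sequence[Sequence[Point]],
    visible: Sequence[Sequence[bool]],
    norm: Sequence[float],
    threshold: float = 0.05,
) -> float:
    """Fraction of visible keypoints whose normalized error is below threshold.

    pred, gt and visible are indexed [instance][keypoint]; norm holds one
    normalization factor (e.g. a bounding-box size) per instance.
    """
    correct = 0
    total = 0
    for p_inst, g_inst, v_inst, scale in zip(pred, gt, visible, norm):
        for (px, py), (gx, gy), vis in zip(p_inst, g_inst, v_inst):
            if not vis:
                continue  # unannotated keypoints do not count
            total += 1
            # Euclidean error divided by the instance's normalization factor.
            dist = math.hypot(px - gx, py - gy) / scale
            if dist < threshold:
                correct += 1
    return correct / total if total else 0.0


# One instance, normalization 100: the threshold is 5 pixels.
score = pck(
    pred=[[(10, 10), (50, 50), (0, 0)]],
    gt=[[(12, 10), (60, 50), (100, 100)]],
    visible=[[True, True, False]],
    norm=[100.0],
)
print(score)  # 0.5
```

The first keypoint is 2 pixels off (correct), the second is 10 pixels off (incorrect), and the third is ignored because it is not visible.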

forward_for_tracing(image: Tensor) → Tensor | tuple[Tensor]

Model forward function used for tracing the model during export.

get_classification_layers(prefix: str = 'model.') → dict[str, dict[str, int]]

Get information about the final classification layers, used for incremental learning.
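In incremental learning the label set can change between training runs, so the rows of the final classification layer have to be reordered or extended to match the new labels. The sketch below shows one way to do that, assuming each class owns stride consecutive rows. The remap_class_rows function and the layer_info dictionary (its key and its "stride" field) are hypothetical; they only mirror the dict[str, dict[str, int]] return type of get_classification_layers() and may not match the keys OTX actually returns.

```python
from typing import Dict, List, Sequence


def remap_class_rows(
    old_rows: Sequence[Sequence[float]],
    old_labels: Sequence[str],
    new_labels: Sequence[str],
    stride: int = 1,
    fill: float = 0.0,
) -> List[List[float]]:
    """Reorder per-class rows of a classification weight for a new label set.

    Each class owns `stride` consecutive rows. Classes present in both label
    sets keep their trained rows; new classes get rows filled with `fill`.
    """
    width = len(old_rows[0]) if old_rows else 0
    old_index = {name: i for i, name in enumerate(old_labels)}
    new_rows: List[List[float]] = []
    for name in new_labels:
        if name in old_index:
            # Copy the trained rows belonging to this class.
            start = old_index[name] * stride
            new_rows.extend(list(row) for row in old_rows[start:start + stride])
        else:
            # Unseen class: start from freshly initialized (constant) rows.
            new_rows.extend([fill] * width for _ in range(stride))
    return new_rows


# Hypothetical layer metadata; real keys and fields may differ.
layer_info: Dict[str, Dict[str, int]] = {"model.head.cls.weight": {"stride": 1}}

old = [[1.0, 1.0], [2.0, 2.0]]  # rows for "nose" and "eye"
new = remap_class_rows(
    old,
    old_labels=["nose", "eye"],
    new_labels=["eye", "ear", "nose"],
    stride=layer_info["model.head.cls.weight"]["stride"],
)
print(new)  # [[2.0, 2.0], [0.0, 0.0], [1.0, 1.0]]
```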

get_dummy_input(batch_size: int = 1) → KeypointDetBatchDataEntity

Generate a dummy input suitable for calling forward() on.

Parameters:

batch_size (int, optional) – number of samples in the dummy batch. Defaults to 1.

Returns:

An entity containing randomly generated inference data.

Return type:

KeypointDetBatchDataEntity
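A dummy batch only needs the right structure and shapes, not meaningful content. The sketch below builds one with the standard library, assuming input_size is (height, width) and images have 3 channels. DummyKeypointBatch and make_dummy_batch are hypothetical stand-ins for KeypointDetBatchDataEntity and get_dummy_input(), which presumably use torch tensors rather than nested lists.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class DummyKeypointBatch:
    """Hypothetical stand-in for KeypointDetBatchDataEntity."""

    batch_size: int
    images: List[List[List[List[float]]]]  # indexed [batch][channel][row][col]
    img_size: Tuple[int, int]


def make_dummy_batch(
    input_size: Tuple[int, int],
    batch_size: int = 1,
    channels: int = 3,
    seed: int = 0,
) -> DummyKeypointBatch:
    """Fill a (batch, channels, height, width) structure with random values."""
    rng = random.Random(seed)  # seeded so repeated calls give the same data
    height, width = input_size
    images = [
        [[[rng.random() for _ in range(width)] for _ in range(height)] for _ in range(channels)]
        for _ in range(batch_size)
    ]
    return DummyKeypointBatch(batch_size=batch_size, images=images, img_size=input_size)


batch = make_dummy_batch(input_size=(4, 6), batch_size=2)
print(
    len(batch.images),
    len(batch.images[0]),
    len(batch.images[0][0]),
    len(batch.images[0][0][0]),
)  # 2 3 4 6
```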