otx.core.model.keypoint_detection
Class definition for keypoint detection model entity used in OTX.
Classes
OTXKeypointDetectionModel | Base class for the keypoint detection models used in OTX.
- class otx.core.model.keypoint_detection.OTXKeypointDetectionModel(label_info: LabelInfoTypes, input_size: tuple[int, int], optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _pck_measure_callable>, torch_compile: bool = False)
Bases: OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]
Base class for the keypoint detection models used in OTX.
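The default metric is _pck_measure_callable, which, judging by its name, builds a PCK (Percentage of Correct Keypoints) measure. The sketch below shows the idea behind PCK in plain Python; it is not OTX's implementation. The function name pck, the normalization factor norm, the threshold thr, and the visibility mask are assumptions chosen for illustration.

```python
import math
from typing import Sequence

Point = tuple[float, float]


def pck(
    pred: Sequence[Point],
    gt: Sequence[Point],
    visible: Sequence[bool],
    norm: float,
    thr: float = 0.05,
) -> float:
    """Hypothetical PCK sketch: fraction of visible keypoints predicted within thr * norm."""
    if not (len(pred) == len(gt) == len(visible)):
        raise ValueError("pred, gt and visible must have the same length")
    if norm <= 0:
        raise ValueError("norm must be positive")
    correct = 0
    total = 0
    for (px, py), (gx, gy), vis in zip(pred, gt, visible):
        if not vis:  # invisible / unannotated keypoints are ignored
            continue
        total += 1
        # Euclidean error, normalized (e.g. by bbox size or image size, an assumption here)
        if math.hypot(px - gx, py - gy) / norm < thr:
            correct += 1
    # Design choice for this sketch: report 0.0 when no keypoint is visible
    return correct / total if total else 0.0
```

For example, with norm=100 and thr=0.05, errors of 1, 2 and 10 pixels give two correct keypoints out of three. Which normalization to use (bounding box, torso, or image size) depends on the dataset convention.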
- forward_for_tracing(image: Tensor) → Tensor | tuple[Tensor]
Model forward function used for the model tracing during model exportation.
- get_classification_layers(prefix: str = 'model.') → dict[str, dict[str, int]]
Get the final classification layer information for the incremental learning case.
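The return type maps a layer name (presumably with prefix prepended) to a small dict of integers. One way to discover such layers, sketched below under stated assumptions, is to build the model twice, once with N classes and once with N + 1, and compare parameter shapes: the parameters whose leading dimension changes belong to the classification layers. The helper find_classification_layers, the use of plain shape tuples instead of real state dicts, and the "stride" / "num_extra_classes" keys are hypothetical illustrations, not the confirmed OTX implementation.

```python
def find_classification_layers(
    shapes_n: dict[str, tuple[int, ...]],
    shapes_n_plus_1: dict[str, tuple[int, ...]],
    n: int,
    prefix: str = "model.",
) -> dict[str, dict[str, int]]:
    """Hypothetical helper: diff parameter shapes of models built with n and n + 1 classes."""
    layers: dict[str, dict[str, int]] = {}
    for key, shape_a in shapes_n.items():
        shape_b = shapes_n_plus_1[key]
        if shape_a == shape_b:
            continue  # parameter does not depend on the number of classes
        dim_a, dim_b = shape_a[0], shape_b[0]
        # Outputs added per extra class (1 for a plain linear head)
        stride = dim_b - dim_a
        # Outputs that are not tied to a class, e.g. a background logit
        num_extra_classes = (n + 1) * dim_a - n * dim_b
        layers[prefix + key] = {"stride": stride, "num_extra_classes": num_extra_classes}
    return layers
```

For a plain linear head of shape (5, 256) versus (6, 256), this yields stride 1 and num_extra_classes 0; a head that carries one extra non-class output, (6, 256) versus (7, 256), yields num_extra_classes 1.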
- get_dummy_input(batch_size: int = 1) → KeypointDetBatchDataEntity
Generate a dummy input suitable for passing to forward().
- Parameters:
batch_size (int, optional) – number of elements in a dummy input sequence. Defaults to 1.
- Returns:
An entity containing randomly generated inference data.
- Return type:
KeypointDetBatchDataEntity
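get_dummy_input() returns randomly generated inference data, which is useful for smoke-testing forward() or export. The sketch below imitates that idea without OTX or PyTorch: it fills a batch of batch_size random images whose height and width come from the model's input_size. DummyBatch, make_dummy_batch, the [batch][channel][row][col] nesting order, the 3-channel default, and the seed argument are all hypothetical; the real method returns a KeypointDetBatchDataEntity.

```python
from __future__ import annotations

import random
from dataclasses import dataclass


@dataclass
class DummyBatch:
    """Hypothetical stand-in for KeypointDetBatchDataEntity (not the OTX class)."""

    batch_size: int
    images: list[list[list[list[float]]]]  # indexed as [batch][channel][row][col]
    img_shape: tuple[int, int]  # (height, width) shared by every image


def make_dummy_batch(
    input_size: tuple[int, int],
    batch_size: int = 1,
    channels: int = 3,
    seed: int | None = None,
) -> DummyBatch:
    """Build a batch of uniformly random images in [0, 1) sized by input_size."""
    rng = random.Random(seed)  # seeded for reproducible dummy data
    height, width = input_size
    images = [
        [[[rng.random() for _ in range(width)] for _ in range(height)] for _ in range(channels)]
        for _ in range(batch_size)
    ]
    return DummyBatch(batch_size=batch_size, images=images, img_shape=(height, width))
```

Fixing the seed keeps the dummy data identical across runs, which helps when comparing outputs before and after export.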