otx.algo.classification.timm_model

TIMM wrapper model classes for OTX.

Classes

TimmModelForHLabelCls(label_info, ...)

TimmModel for hierarchical label classification task.

TimmModelForMulticlassCls(label_info, ...)

TimmModel for multi-class classification task.

TimmModelForMultilabelCls(label_info, ...)

TimmModel for multi-label classification task.

class otx.algo.classification.timm_model.TimmModelForHLabelCls(label_info: HLabelInfo, model_name: str, input_size: tuple[int, int] = (224, 224), pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)

Bases: OTXHlabelClsModel

TimmModel for hierarchical label classification task.

Parameters:
  • label_info (HLabelInfo) – The label information for the classification task.

  • model_name (str) – The name of the model. You can find available models at timm.list_models() or timm.list_pretrained().

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).

  • pretrained (bool, optional) – Whether to load pretrained weights. Defaults to True.

  • optimizer (OptimizerCallable, optional) – The optimizer callable for training the model.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable.

  • metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to HLabelClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using torch.compile. Defaults to False.
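The `model_name` used in the examples, `tf_efficientnetv2_s.in21k`, follows timm's `architecture.pretrained_tag` naming. As a minimal sketch of that convention, the helper below splits such a name using only the standard library. `split_timm_name` is a hypothetical name introduced here for illustration; it is not part of OTX or timm.

```python
from __future__ import annotations


def split_timm_name(model_name: str) -> tuple[str, str | None]:
    """Split 'architecture.tag' into (architecture, tag).

    The tag is None when the name carries no pretrained tag.
    """
    # partition splits on the first dot only, which separates the
    # architecture from the pretrained tag in names such as
    # "tf_efficientnetv2_s.in21k".
    architecture, separator, tag = model_name.partition(".")
    return architecture, (tag if separator else None)


print(split_timm_name("tf_efficientnetv2_s.in21k"))
```

A name without a tag, such as `resnet50`, yields `None` as the second element.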

Example

  1. API
    >>> model = TimmModelForHLabelCls(
    ...     model_name="tf_efficientnetv2_s.in21k",
    ...     label_info=<h-label-info>,
    ... )
    
  2. CLI
    >>> otx train \
    ...     --model otx.algo.classification.timm_model.TimmModelForHLabelCls \
    ...     --model.model_name tf_efficientnetv2_s.in21k
    
forward_explain(inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPredEntity

Model forward explain function.

forward_for_tracing(image: Tensor) -> Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.
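Because `input_size` is given in (height, width) order, an `image` tensor for tracing would typically be laid out as (batch, channels, height, width). The sketch below only computes that shape. The batch size of 1 and the 3 channels are assumptions, and `tracing_input_shape` is a hypothetical helper, not an OTX function.

```python
def tracing_input_shape(input_size=(224, 224), batch=1, channels=3):
    """Return an NCHW shape tuple for a dummy tracing input.

    input_size is (height, width), matching the constructor argument.
    batch and channels are illustrative defaults, not OTX settings.
    """
    height, width = input_size
    return (batch, channels, height, width)


print(tracing_input_shape())
print(tracing_input_shape((192, 256)))
```

Keeping the height-then-width order explicit avoids silently transposing non-square inputs.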

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') -> dict

Load a checkpoint from the previous OTX version (v1), converting its state dict to the OTX 2.0 format.
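The `add_prefix` argument suggests that checkpoint keys are re-namespaced during conversion. As an illustration only (the actual conversion in OTX may rename or drop keys in other ways), prefixing every key of a state dict could look like the sketch below. `prefix_state_dict_keys` is a hypothetical helper, not the OTX implementation.

```python
def prefix_state_dict_keys(state_dict: dict, add_prefix: str = "model.") -> dict:
    """Return a new dict whose keys all start with add_prefix.

    Values are passed through untouched; the input dict is not modified.
    """
    return {f"{add_prefix}{key}": value for key, value in state_dict.items()}


# Plain Python values stand in for tensors here.
old = {"backbone.conv.weight": [0.1, 0.2], "head.fc.bias": [0.0]}
new = prefix_state_dict_keys(old)
print(sorted(new))
```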

class otx.algo.classification.timm_model.TimmModelForMulticlassCls(label_info: LabelInfoTypes, model_name: str, input_size: tuple[int, int] = (224, 224), pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)

Bases: OTXMulticlassClsModel

TimmModel for multi-class classification task.

Parameters:
  • label_info (LabelInfoTypes) – The label information for the classification task.

  • model_name (str) – The name of the model. You can find available models at timm.list_models() or timm.list_pretrained().

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).

  • pretrained (bool, optional) – Whether to load pretrained weights. Defaults to True.

  • optimizer (OptimizerCallable, optional) – The optimizer callable for training the model.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable.

  • metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to MultiClassClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using torch.compile. Defaults to False.

  • train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional) – The training type.

Example

  1. API
    >>> model = TimmModelForMulticlassCls(
    ...     model_name="tf_efficientnetv2_s.in21k",
    ...     label_info=<Number-of-classes>,
    ... )
    
  2. CLI
    >>> otx train \
    ...     --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \
    ...     --model.model_name tf_efficientnetv2_s.in21k
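When the CLI is driven from a script, the same invocation can be assembled as an argument list. The sketch below uses only the two options shown in the example above; any further options a real run needs are not shown in the source and are omitted here. `build_train_command` is a hypothetical helper.

```python
import shlex


def build_train_command(model_class: str, model_name: str) -> list[str]:
    """Build the `otx train` argument list from the documented options."""
    return [
        "otx", "train",
        "--model", model_class,
        "--model.model_name", model_name,
    ]


cmd = build_train_command(
    "otx.algo.classification.timm_model.TimmModelForMulticlassCls",
    "tf_efficientnetv2_s.in21k",
)
# shlex.join renders the list as a single shell-safe command string.
print(shlex.join(cmd))
```

Passing the list form to `subprocess.run` avoids shell quoting issues; the joined string is mainly useful for logging.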
    
forward_explain(inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity

Model forward explain function.

forward_for_tracing(image: Tensor) -> Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') -> dict

Load a checkpoint from the previous OTX version (v1), converting its state dict to the OTX 2.0 format.

class otx.algo.classification.timm_model.TimmModelForMultilabelCls(label_info: LabelInfoTypes, model_name: str, input_size: tuple[int, int] = (224, 224), pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False)

Bases: OTXMultilabelClsModel

TimmModel for multi-label classification task.
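Multi-label classification differs from multi-class classification in that each label is decided independently, so an image can receive several labels or none. The pure-Python sketch below contrasts the two decoding rules. The 0.5 threshold and the helper names are assumptions for illustration, not OTX's implementation.

```python
import math


def decode_multiclass(logits):
    """Pick the single highest-scoring class index.

    The argmax of the logits equals the argmax of their softmax.
    """
    return max(range(len(logits)), key=lambda i: logits[i])


def decode_multilabel(logits, threshold=0.5):
    """Keep every class whose sigmoid score reaches the threshold."""
    return [
        i for i, x in enumerate(logits)
        if 1.0 / (1.0 + math.exp(-x)) >= threshold
    ]


logits = [2.0, -1.0, 0.5]
print(decode_multiclass(logits))  # 0
print(decode_multilabel(logits))  # [0, 2]
```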

Parameters:
  • label_info (LabelInfoTypes) – The label information for the classification task.

  • model_name (str) – The name of the model. You can find available models at timm.list_models() or timm.list_pretrained().

  • input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).

  • pretrained (bool, optional) – Whether to load pretrained weights. Defaults to True.

  • optimizer (OptimizerCallable, optional) – The optimizer callable for training the model.

  • scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable.

  • metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to MultiLabelClsMetricCallable.

  • torch_compile (bool, optional) – Whether to compile the model using torch.compile. Defaults to False.

Example

  1. API
    >>> model = TimmModelForMultilabelCls(
    ...     model_name="tf_efficientnetv2_s.in21k",
    ...     label_info=<Number-of-classes>,
    ... )
    
  2. CLI
    >>> otx train \
    ...     --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \
    ...     --model.model_name tf_efficientnetv2_s.in21k
    
forward_explain(inputs: MultilabelClsBatchDataEntity) -> MultilabelClsBatchPredEntity

Model forward explain function.

forward_for_tracing(image: Tensor) -> Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') -> dict

Load a checkpoint from the previous OTX version (v1), converting its state dict to the OTX 2.0 format.