otx.algo.classification.timm_model
TIMM wrapper model class for OTX.
Classes
TimmModelForHLabelCls | Timm Model for hierarchical label classification task.
TimmModelForMulticlassCls | TimmModel for multi-class classification task.
TimmModelForMultilabelCls | TimmModel for multi-label classification task.
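The three classes differ only in the classification task they target. A minimal sketch of choosing the fully qualified class path for a task, for example to pass to the CLI `--model` option; the task keys ("h-label", "multi-class", "multi-label") and the helper name are hypothetical, not OTX identifiers:

```python
from __future__ import annotations

# Hypothetical task keys mapped to the class paths documented on this page.
TIMM_MODEL_CLASSES: dict[str, str] = {
    "h-label": "otx.algo.classification.timm_model.TimmModelForHLabelCls",
    "multi-class": "otx.algo.classification.timm_model.TimmModelForMulticlassCls",
    "multi-label": "otx.algo.classification.timm_model.TimmModelForMultilabelCls",
}


def timm_model_class_path(task: str) -> str:
    """Return the TimmModel class path for a (hypothetical) task key."""
    try:
        return TIMM_MODEL_CLASSES[task]
    except KeyError:
        valid = ", ".join(sorted(TIMM_MODEL_CLASSES))
        raise ValueError(f"Unknown task {task!r}; expected one of: {valid}") from None


print(timm_model_class_path("multi-label"))
```

Raising on an unknown key, rather than falling back to a default class, avoids silently training the wrong task head.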
- class otx.algo.classification.timm_model.TimmModelForHLabelCls(label_info: HLabelInfo, model_name: str, input_size: tuple[int, int] = (224, 224), pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)
Bases: OTXHlabelClsModel
Timm Model for hierarchical label classification task.
- Parameters:
label_info (HLabelInfo) – The label information for the classification task.
model_name (str) – The name of the model. You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).
pretrained (bool, optional) – Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional) – The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable.
metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to HLabelClsMetricCallable.
torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
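The model_name used in the examples, tf_efficientnetv2_s.in21k, follows timm's convention of naming pretrained weights as <architecture>.<pretrained_tag>. A small sketch that splits such a name into its two parts, assuming that convention; the helper name is hypothetical:

```python
from __future__ import annotations


def split_timm_model_name(model_name: str) -> tuple[str, str | None]:
    """Split a timm model name into (architecture, pretrained_tag).

    Assumes timm's "<architecture>.<pretrained_tag>" naming; a name with no
    tag yields None as the tag.
    """
    # Only the first dot separates the architecture from the tag.
    architecture, sep, tag = model_name.partition(".")
    return architecture, (tag if sep else None)


print(split_timm_model_name("tf_efficientnetv2_s.in21k"))
# ('tf_efficientnetv2_s', 'in21k')
```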
Example
- API
>>> model = TimmModelForHLabelCls(
...     model_name="tf_efficientnetv2_s.in21k",
...     label_info=<h-label-info>,
... )
- CLI
>>> otx train \
...     --model otx.algo.classification.timm_model.TimmModelForHLabelCls \
...     --model.model_name tf_efficientnetv2_s.in21k
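The CLI example can also be assembled programmatically, for instance to launch several runs from a script. A minimal sketch that builds the argument list from the two options shown in the example (--model and --model.model_name); any further otx train arguments, such as data or config options, are passed through the hypothetical extra_args parameter and are not interpreted here:

```python
from __future__ import annotations

import shlex


def build_otx_train_args(
    model_class_path: str,
    model_name: str,
    extra_args: list[str] | None = None,
) -> list[str]:
    """Build an `otx train` argument list for a TimmModel class.

    Only --model and --model.model_name come from the documented example;
    extra_args (hypothetical) is appended unchanged.
    """
    args = [
        "otx",
        "train",
        "--model",
        model_class_path,
        "--model.model_name",
        model_name,
    ]
    # Append any additional CLI arguments verbatim.
    args.extend(extra_args or [])
    return args


cmd = build_otx_train_args(
    "otx.algo.classification.timm_model.TimmModelForHLabelCls",
    "tf_efficientnetv2_s.in21k",
)
# shlex.join quotes each argument only if the shell would need it.
print(shlex.join(cmd))
```

Keeping the command as a list lets it be handed directly to subprocess.run(cmd, check=True) without shell quoting issues.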
- forward_explain(inputs: HlabelClsBatchDataEntity) → HlabelClsBatchPredEntity
Run the forward pass in explain mode, returning the batch predictions together with explanation outputs such as saliency maps.
- class otx.algo.classification.timm_model.TimmModelForMulticlassCls(label_info: LabelInfoTypes, model_name: str, input_size: tuple[int, int] = (224, 224), pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)
Bases: OTXMulticlassClsModel
TimmModel for multi-class classification task.
- Parameters:
label_info (LabelInfoTypes) – The label information for the classification task.
model_name (str) – The name of the model. You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).
pretrained (bool, optional) – Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional) – The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable.
metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to MultiClassClsMetricCallable.
torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional) – The training type, either supervised or semi-supervised. Defaults to OTXTrainType.SUPERVISED.
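The Literal annotation limits train_type to two OTXTrainType members, but Python does not enforce type hints at runtime. A minimal sketch of an equivalent explicit check, using a local stand-in enum; TrainType, its OTHER member, and the string values are hypothetical and not taken from OTX:

```python
from enum import Enum


class TrainType(str, Enum):
    """Hypothetical stand-in for OTXTrainType; member values are assumed."""

    SUPERVISED = "SUPERVISED"
    SEMI_SUPERVISED = "SEMI_SUPERVISED"
    OTHER = "OTHER"  # hypothetical member outside the allowed set


# Mirrors Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED].
ALLOWED_MULTICLASS_TRAIN_TYPES = frozenset(
    {TrainType.SUPERVISED, TrainType.SEMI_SUPERVISED}
)


def check_train_type(train_type: TrainType = TrainType.SUPERVISED) -> TrainType:
    """Return train_type unchanged, or raise ValueError if it is not allowed."""
    if train_type not in ALLOWED_MULTICLASS_TRAIN_TYPES:
        raise ValueError(f"Unsupported train_type: {train_type.value}")
    return train_type


print(check_train_type().value)
```

The default argument mirrors the documented default of OTXTrainType.SUPERVISED.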
Example
- API
>>> model = TimmModelForMulticlassCls(
...     model_name="tf_efficientnetv2_s.in21k",
...     label_info=<Number-of-classes>,
... )
- CLI
>>> otx train \
...     --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \
...     --model.model_name tf_efficientnetv2_s.in21k
- forward_explain(inputs: MulticlassClsBatchDataEntity) → MulticlassClsBatchPredEntity
Run the forward pass in explain mode, returning the batch predictions together with explanation outputs such as saliency maps.
- class otx.algo.classification.timm_model.TimmModelForMultilabelCls(label_info: LabelInfoTypes, model_name: str, input_size: tuple[int, int] = (224, 224), pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False)
Bases: OTXMultilabelClsModel
TimmModel for multi-label classification task.
- Parameters:
label_info (LabelInfoTypes) – The label information for the classification task.
model_name (str) – The name of the model. You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional) – Model input size in the order of height and width. Defaults to (224, 224).
pretrained (bool, optional) – Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional) – The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable.
metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to MultiLabelClsMetricCallable.
torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
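input_size is given as (height, width), the same order as the last two dimensions of PyTorch's NCHW image layout. A sketch that turns it into a dummy batch shape, for example to size a test input; the helper name and its defaults (batch of 1, 3 RGB channels) are assumptions:

```python
from __future__ import annotations


def dummy_input_shape(
    input_size: tuple[int, int],
    batch_size: int = 1,
    channels: int = 3,
) -> tuple[int, int, int, int]:
    """Return an NCHW shape for a (height, width) input_size."""
    height, width = input_size
    if height <= 0 or width <= 0:
        raise ValueError(f"input_size must be positive, got {input_size}")
    # Height precedes width, matching the order of the input_size argument.
    return (batch_size, channels, height, width)


print(dummy_input_shape((256, 128), batch_size=4))
# (4, 3, 256, 128)
```

A non-square size such as (256, 128) makes the ordering visible: 256 lands in the height slot and 128 in the width slot.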
Example
- API
>>> model = TimmModelForMultilabelCls(
...     model_name="tf_efficientnetv2_s.in21k",
...     label_info=<Number-of-classes>,
... )
- CLI
>>> otx train \
...     --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \
...     --model.model_name tf_efficientnetv2_s.in21k
- forward_explain(inputs: MultilabelClsBatchDataEntity) → MultilabelClsBatchPredEntity
Run the forward pass in explain mode, returning the batch predictions together with explanation outputs such as saliency maps.