otx.algo.classification.hlabel_models
Hierarchical label (H-label) classification models package.
Classes
EfficientNetHLabelCls | EfficientNet model for the hierarchical label classification task.
MobileNetV3HLabelCls | MobileNetV3 model for the hierarchical label classification task.
TimmModelHLabelCls | Timm model for the hierarchical label classification task.
TVModelHLabelCls | Torchvision model for hierarchical label classification.
VisionTransformerHLabelCls | Model for hierarchical label classification using the ViT architecture.
- class otx.algo.classification.hlabel_models.EfficientNetHLabelCls(label_info: HLabelInfo, data_input_params: DataInputParams, model_name: str = 'efficientnet_b0', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)
Bases: OTXHlabelClsModel
EfficientNet Model for hierarchical label classification task.
Initialize the base model with the given parameters.
- Parameters:
label_info (HLabelInfo) – Information about the hierarchical labels used in the model.
data_input_params (DataInputParams) – Parameters of the input data such as input size, mean, and std.
model_name (str, optional) – Name of the model. Defaults to 'efficientnet_b0'.
optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.
metric (MetricCallable, optional) – Callable for the metric. Defaults to _mixed_hlabel_accuracy.
torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.
tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False). Not listed in the constructor signature above.
- Returns:
None
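The optimizer and scheduler arguments are callables rather than ready-made instances. The sketch below illustrates that factory pattern in plain Python. ToySGD and ToyModel are hypothetical stand-ins invented for this illustration, not OTX or PyTorch classes, and the way OTX itself invokes the callable is assumed rather than taken from its source.

```python
from functools import partial
from typing import Callable, Iterable, List


class ToySGD:
    """Hypothetical stand-in for a real optimizer class."""

    def __init__(self, params: Iterable[float], lr: float = 0.01) -> None:
        self.params: List[float] = list(params)
        self.lr = lr


# A factory maps model parameters to a ready-to-use optimizer.
OptimizerFactory = Callable[[Iterable[float]], ToySGD]


class ToyModel:
    """Hypothetical model that stores the factory and calls it later."""

    def __init__(self, optimizer: OptimizerFactory = partial(ToySGD, lr=0.01)) -> None:
        self.weights = [0.5, -1.0]
        self.optimizer_factory = optimizer

    def configure_optimizer(self) -> ToySGD:
        # The parameters exist only once the model is built, so the
        # optimizer is created here instead of being passed in ready-made.
        return self.optimizer_factory(self.weights)


# Override the default learning rate without touching the model code.
model = ToyModel(optimizer=partial(ToySGD, lr=0.1))
opt = model.configure_optimizer()
print(opt.lr, opt.params)
```

Passing a partially applied constructor keeps hyperparameters such as the learning rate configurable while leaving the binding to the model parameters to the model itself.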
- class otx.algo.classification.hlabel_models.MobileNetV3HLabelCls(label_info: HLabelInfo, data_input_params: DataInputParams, model_name: str = 'mobilenetv3_large', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)
Bases: OTXHlabelClsModel
MobileNetV3 Model for hierarchical label classification task.
Initialize the base model with the given parameters.
- Parameters:
label_info (HLabelInfo) – Information about the hierarchical labels used in the model.
data_input_params (DataInputParams) – Parameters of the input data such as input size, mean, and std.
model_name (str, optional) – Name of the model. Defaults to 'mobilenetv3_large'.
optimizer (OptimizerCallable, optional) – Callable for the optimizer. Defaults to DefaultOptimizerCallable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler. Defaults to DefaultSchedulerCallable.
metric (MetricCallable, optional) – Callable for the metric. Defaults to _mixed_hlabel_accuracy.
torch_compile (bool, optional) – Flag to indicate if torch.compile should be used. Defaults to False.
tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False). Not listed in the constructor signature above.
- Returns:
None
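The data_input_params argument is described as carrying the input size, mean, and std. A common use of such values is per-channel normalization, sketched below. ToyInputParams is a hypothetical container written for this example; the field names follow the description above, and the real DataInputParams class may differ in layout and behavior.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class ToyInputParams:
    """Hypothetical container; not the OTX DataInputParams class."""

    input_size: Tuple[int, int]
    mean: Tuple[float, float, float]
    std: Tuple[float, float, float]


def normalize_pixel(pixel: Sequence[float], params: ToyInputParams) -> List[float]:
    """Apply (value - mean) / std to each channel of one pixel."""
    return [(v - m) / s for v, m, s in zip(pixel, params.mean, params.std)]


# Illustrative values only; real datasets use their own statistics.
params = ToyInputParams(
    input_size=(224, 224),
    mean=(100.0, 110.0, 120.0),
    std=(50.0, 55.0, 60.0),
)
print(normalize_pixel((150.0, 110.0, 60.0), params))
```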
- class otx.algo.classification.hlabel_models.TVModelHLabelCls(label_info: HLabelInfo, data_input_params: DataInputParams, model_name: str = 'efficientnet_v2_s', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)
Bases: OTXHlabelClsModel
TVModelHLabelCls represents a Torchvision model for hierarchical label classification.
Initialize the base model with the given parameters.
- Parameters:
label_info (HLabelInfo) – Information about the hierarchical labels.
data_input_params (DataInputParams) – Parameters of the input data such as input size, mean, and std.
model_name (str, optional) – Name of the Torchvision model. Defaults to 'efficientnet_v2_s'.
optimizer (OptimizerCallable, optional) – The optimizer callable. Defaults to DefaultOptimizerCallable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.
metric (MetricCallable, optional) – The metric callable. Defaults to _mixed_hlabel_accuracy.
torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False). Not listed in the constructor signature above.
- Returns:
None
- class otx.algo.classification.hlabel_models.TimmModelHLabelCls(label_info: HLabelInfo, data_input_params: DataInputParams, model_name: str = 'tf_efficientnetv2_s.in21k', optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)
Bases: OTXHlabelClsModel
Timm Model for hierarchical label classification task.
Initialize the base model with the given parameters.
- Parameters:
label_info (HLabelInfo) – The label information for the classification task.
data_input_params (DataInputParams) – Parameters of the input data such as input size, mean, and std.
model_name (str, optional) – The name of the model. Available names are listed by timm.list_models() and timm.list_pretrained(). Defaults to 'tf_efficientnetv2_s.in21k'.
optimizer (OptimizerCallable, optional) – The optimizer callable for training the model. Defaults to DefaultOptimizerCallable.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – The learning rate scheduler callable. Defaults to DefaultSchedulerCallable.
metric (MetricCallable, optional) – The metric callable for evaluating the model. Defaults to _mixed_hlabel_accuracy.
torch_compile (bool, optional) – Whether to compile the model with torch.compile. Defaults to False.
tile_config (TileConfig, optional) – Configuration for tiling. Defaults to TileConfig(enable_tiler=False). Not listed in the constructor signature above.
- Returns:
None
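Since model_name must be a name that timm recognizes, it can help to search the available names before constructing the model. The sketch below filters a name list with a shell-style pattern. match_model_names is a hypothetical helper written for this example, and the fallback list used when timm is not installed is illustrative only (its second entry is not a real model).

```python
import fnmatch
from typing import Iterable, List


def match_model_names(pattern: str, names: Iterable[str]) -> List[str]:
    """Return the names matching a shell-style pattern, sorted."""
    return sorted(name for name in names if fnmatch.fnmatch(name, pattern))


try:
    import timm

    # Names of models that ship with pretrained weights.
    available = timm.list_models(pretrained=True)
except ImportError:
    # Illustrative fallback so the sketch still runs without timm.
    available = ["tf_efficientnetv2_s.in21k", "toy_model.demo"]

print(match_model_names("tf_efficientnetv2_s*", available))
```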
- class otx.algo.classification.hlabel_models.VisionTransformerHLabelCls(label_info: HLabelInfo, data_input_params: DataInputParams, model_name: str = 'vit-tiny', lora: bool = False, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False)
Bases: ForwardExplainMixInForViT, OTXHlabelClsModel
VisionTransformerHLabelCls is a model designed for hierarchical label classification using the ViT architecture.
- Parameters:
label_info (HLabelInfo) – Information about the hierarchical labels.
data_input_params (DataInputParams) – Parameters for data input.
model_name (str, optional) – Name of the Vision Transformer model to use. Defaults to 'vit-tiny'.
lora (bool, optional) – Whether to use LoRA (Low-Rank Adaptation) for the model. Defaults to False.
optimizer (OptimizerCallable, optional) – Callable for the optimizer.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional) – Callable for the learning rate scheduler.
metric (MetricCallable, optional) – Callable for the metric. Defaults to _mixed_hlabel_accuracy.
torch_compile (bool, optional) – Whether to use torch.compile for the model. Defaults to False.
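The lora flag refers to Low-Rank Adaptation, where a frozen weight matrix W is augmented by a trainable low-rank product B·A. The sketch below shows that idea for a single linear layer in plain Python. It is a generic illustration under assumed shapes and scaling (alpha / r), not the code path OTX takes when lora=True; all function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, alpha: float = 1.0) -> Vector:
    """Compute W x + (alpha / r) * B (A x).

    Assumed shapes: W is d_out x d_in (frozen), A is r x d_in and
    B is d_out x r (trainable), with rank r much smaller than d_in.
    """
    r = len(A)
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]   # frozen pretrained weight
A = [[1.0, 1.0]]               # rank r = 1
B_zero = [[0.0], [0.0]]        # zero-initialized B leaves the output unchanged
B_trained = [[1.0], [2.0]]     # illustrative values after some training
x = [2.0, 3.0]

print(lora_forward(W, A, B_zero, x))
print(lora_forward(W, A, B_trained, x))
```

Only A and B would be updated during fine-tuning, which is why the technique needs far fewer trainable parameters than updating W directly.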