otx.algo.classification.efficientnet

EfficientNet model implementations for classification tasks (EfficientNet-B0 by default).

Classes

EfficientNetForHLabelCls(label_info, ...)

EfficientNet model for the hierarchical label classification task.

EfficientNetForMulticlassCls(label_info, ...)

EfficientNet model for the multi-class classification task.

EfficientNetForMultilabelCls(label_info, ...)

EfficientNet model for the multi-label classification task.

class otx.algo.classification.efficientnet.EfficientNetForHLabelCls(label_info: HLabelInfo, version: EFFICIENTNET_VERSION = 'b0', pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: OTXHlabelClsModel

EfficientNet model for the hierarchical label classification task.
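To make "hierarchical label classification" concrete, here is a minimal sketch of how h-label predictions can be decoded, assuming the common setup where the label tree is split into mutually exclusive groups (one softmax per group, exactly one winner each) plus independent multi-label classes (one sigmoid each, zero or more winners). The flat logit layout, the `decode_hlabel` helper, the 0.5 threshold, and the example label names are all hypothetical; this is not OTX's `HLabelInfo` API or the model's actual post-processing.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one group of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_hlabel(
    logits: list[float],
    exclusive_groups: list[list[str]],
    multilabel_names: list[str],
    threshold: float = 0.5,
) -> list[str]:
    """Hypothetical h-label decoder.

    Assumed logit layout: each exclusive group's logits in order, followed by
    one logit per independent multi-label class.
    """
    expected = sum(len(g) for g in exclusive_groups) + len(multilabel_names)
    if len(logits) != expected:
        raise ValueError(f"expected {expected} logits, got {len(logits)}")

    predicted: list[str] = []
    offset = 0
    # One softmax per mutually exclusive group: exactly one winner per group.
    for group in exclusive_groups:
        probs = softmax(logits[offset : offset + len(group)])
        best = max(range(len(group)), key=probs.__getitem__)
        predicted.append(group[best])
        offset += len(group)

    # Independent sigmoid per remaining label: zero or more winners.
    for name, x in zip(multilabel_names, logits[offset:]):
        if sigmoid(x) >= threshold:
            predicted.append(name)
    # Parent/child consistency across groups is NOT enforced in this sketch.
    return predicted


labels = decode_hlabel(
    [1.0, -1.0, 0.2, 0.8, 1.5, -2.0],
    exclusive_groups=[["animal", "vehicle"], ["cat", "dog"]],
    multilabel_names=["outdoor", "blurry"],
)
print(labels)  # ['animal', 'dog', 'outdoor']
```

A real hierarchical decoder would typically also keep child predictions consistent with the chosen parent; that step is left out here to keep the group-wise softmax/sigmoid split visible.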

forward_explain(inputs: HlabelClsBatchDataEntity) → HlabelClsBatchPredEntity

Run the model forward pass in explain mode and return the batch predictions.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Convert a state dict from an OTX 1.x checkpoint into the OTX 2.0 format and return it.
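The `add_prefix` parameter (default `'model.'`) suggests that converted keys end up under a common prefix, presumably matching where the network sits inside the OTX 2.0 model wrapper. The sketch below illustrates only that re-keying step with a hypothetical helper, `add_key_prefix`, using plain lists in place of tensors and made-up key names. It is not the method's implementation, which may also rename OTX 1.x specific keys in ways not shown here.

```python
def add_key_prefix(state_dict: dict, prefix: str = "model.") -> dict:
    """Hypothetical helper: re-key every entry under `prefix`.

    Keys that already carry the prefix are left unchanged, so applying the
    helper twice is harmless. Values are passed through untouched.
    """
    return {
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in state_dict.items()
    }


# Plain lists stand in for tensors; the key names are made up for illustration.
v1_state = {"backbone.conv.weight": [0.1, 0.2], "head.fc.bias": [0.0]}
v2_state = add_key_prefix(v1_state)
print(sorted(v2_state))  # ['model.backbone.conv.weight', 'model.head.fc.bias']
```

Skipping already-prefixed keys is a design choice of this sketch; it keeps the conversion idempotent if a checkpoint has been partially migrated.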

class otx.algo.classification.efficientnet.EfficientNetForMulticlassCls(label_info: LabelInfoTypes, version: EFFICIENTNET_VERSION = 'b0', pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224), train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)

Bases: OTXMulticlassClsModel

EfficientNet model for the multi-class classification task.
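In the multi-class task each image receives exactly one label. A minimal sketch of the usual decoding step, assuming the model emits one logit per class and a softmax followed by argmax selects the prediction; `decode_multiclass` and the label names are hypothetical and not part of the OTX API.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_multiclass(logits: list[float], labels: list[str]) -> tuple[str, float]:
    """Return the single most probable label and its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


label, score = decode_multiclass([2.0, 0.5, -1.0], ["cat", "dog", "bird"])
print(label, round(score, 3))  # cat 0.786
```

Because the softmax probabilities sum to one across all classes, the classes compete with each other, which is what forces a single answer per image.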

forward_explain(inputs: MulticlassClsBatchDataEntity) → MulticlassClsBatchPredEntity

Run the model forward pass in explain mode and return the batch predictions.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Convert a state dict from an OTX 1.x checkpoint into the OTX 2.0 format and return it.

class otx.algo.classification.efficientnet.EfficientNetForMultilabelCls(label_info: LabelInfoTypes, version: EFFICIENTNET_VERSION = 'b0', pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: OTXMultilabelClsModel

EfficientNet model for the multi-label classification task.
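In the multi-label task an image may carry any number of labels, including none. A minimal sketch of the usual decoding step, assuming one logit per label, an independent sigmoid per label, and a fixed 0.5 threshold; `decode_multilabel`, the threshold value, and the label names are hypothetical and not taken from OTX.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_multilabel(
    logits: list[float], labels: list[str], threshold: float = 0.5
) -> list[str]:
    """Keep every label whose independent sigmoid score reaches `threshold`."""
    return [label for label, x in zip(labels, logits) if sigmoid(x) >= threshold]


print(decode_multilabel([2.0, 0.5, -1.0], ["cat", "dog", "bird"]))  # ['cat', 'dog']
```

Unlike the multi-class softmax, the per-label sigmoid scores do not sum to one, so labels do not compete and several (or none) can pass the threshold.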

forward_explain(inputs: MultilabelClsBatchDataEntity) → MultilabelClsBatchPredEntity

Run the model forward pass in explain mode and return the batch predictions.

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used to trace the model during export.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Convert a state dict from an OTX 1.x checkpoint into the OTX 2.0 format and return it.