otx.algo.classification.vit

ViT model implementation.

Classes

ForwardExplainMixInForViT()

ViT model that can attach an XAI (Explainable AI) branch.

VisionTransformerForHLabelCls(label_info, ...)

DeitTiny model for the hierarchical label classification task.

VisionTransformerForMulticlassCls(...)

DeitTiny model for the multi-class classification task.

VisionTransformerForMultilabelCls(...)

DeitTiny model for the multi-label classification task.

class otx.algo.classification.vit.ForwardExplainMixInForViT

Bases: Generic[T_OTXBatchPredEntity, T_OTXBatchDataEntity]

ViT model that can attach an XAI (Explainable AI) branch.

forward_explain(inputs: T_OTXBatchDataEntity) → T_OTXBatchPredEntity

Model forward pass that additionally produces explanation outputs (e.g., saliency maps).

forward_for_tracing(image: Tensor) → Tensor | dict[str, Tensor]

Model forward function used for tracing during model export.

get_explain_fn() → Callable

Returns the explain function.

head_forward_fn(x: Tensor) → Tensor

Runs the forward pass of the model's neck and head.

property has_gap: bool

Defines whether global average pooling (GAP) is applied directly after the backbone.

Note

This property can be overridden at the model level.
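
Since forward_for_tracing accepts a raw image tensor, it is the easiest mix-in method to exercise without building OTX batch entities. A minimal sketch, not taken from the library docs, assuming a plain class count is an accepted LabelInfoTypes value for the multi-class model (verify against your OTX version):

    import torch

    from otx.algo.classification.vit import VisionTransformerForMulticlassCls

    # Assumption: an int class count is accepted for label_info.
    model = VisionTransformerForMulticlassCls(label_info=10)
    model.eval()

    # The traced forward expects a preprocessed image batch matching input_size.
    dummy = torch.randn(1, 3, 224, 224)
    out = model.forward_for_tracing(dummy)  # Tensor or dict[str, Tensor]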

class otx.algo.classification.vit.VisionTransformerForHLabelCls(label_info: HLabelInfo, arch: VIT_ARCH_TYPE = 'vit-tiny', lora: bool = False, pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _mixed_hlabel_accuracy>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: ForwardExplainMixInForViT, OTXHlabelClsModel

DeitTiny model for the hierarchical label classification task.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint trained with the previous OTX version (v1) and convert its state dict to the OTX 2.0 format.
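
An HLabelInfo carries the label-hierarchy metadata and is normally derived from the dataset rather than written by hand. A minimal sketch under that assumption; the import path for HLabelInfo is assumed from the OTX 2.x layout and should be verified:

    from otx.algo.classification.vit import VisionTransformerForHLabelCls
    from otx.core.types.label import HLabelInfo  # assumed import path


    def build_hlabel_vit(label_info: HLabelInfo) -> VisionTransformerForHLabelCls:
        """Build the hierarchical-label ViT from dataset-derived label metadata."""
        return VisionTransformerForHLabelCls(
            label_info=label_info,
            arch="vit-tiny",
            pretrained=True,
        )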

class otx.algo.classification.vit.VisionTransformerForMulticlassCls(label_info: LabelInfoTypes, arch: VIT_ARCH_TYPE = 'vit-tiny', lora: bool = False, pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_class_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224), train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED)

Bases: ForwardExplainMixInForViT, OTXMulticlassClsModel

DeitTiny model for the multi-class classification task.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint trained with the previous OTX version (v1) and convert its state dict to the OTX 2.0 format.
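
A hedged sketch of the v1-to-v2 checkpoint migration. The checkpoint path is a placeholder, and the assumption that the converted dict can be fed straight into load_state_dict should be verified:

    import torch

    from otx.algo.classification.vit import VisionTransformerForMulticlassCls

    model = VisionTransformerForMulticlassCls(label_info=10)  # assumption: int accepted

    # Placeholder path to an OTX v1 checkpoint.
    v1_state = torch.load("otx_v1_weights.pth", map_location="cpu")
    v2_state = model.load_from_otx_v1_ckpt(v1_state, add_prefix="model.")
    # Assumption: the converted keys match the v2 module layout.
    model.load_state_dict(v2_state)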

class otx.algo.classification.vit.VisionTransformerForMultilabelCls(label_info: LabelInfoTypes, arch: VIT_ARCH_TYPE = 'vit-tiny', lora: bool = False, pretrained: bool = True, optimizer: OptimizerCallable = <function _default_optimizer_callable>, scheduler: LRSchedulerCallable | LRSchedulerListCallable = <function _default_scheduler_callable>, metric: MetricCallable = <function _multi_label_cls_metric_callable>, torch_compile: bool = False, input_size: tuple[int, int] = (224, 224))

Bases: ForwardExplainMixInForViT, OTXMultilabelClsModel

DeitTiny model for the multi-label classification task.

load_from_otx_v1_ckpt(state_dict: dict, add_prefix: str = 'model.') → dict

Load a checkpoint trained with the previous OTX version (v1) and convert its state dict to the OTX 2.0 format.
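
The optimizer and scheduler parameters follow the Lightning-style callable convention: a callable that maps model parameters to a configured optimizer. A minimal sketch under that assumption (the int label count is likewise an assumption about LabelInfoTypes):

    from functools import partial

    import torch
    from otx.algo.classification.vit import VisionTransformerForMultilabelCls

    model = VisionTransformerForMultilabelCls(
        label_info=20,  # assumption: an int label count is accepted
        optimizer=partial(torch.optim.AdamW, lr=1e-4),  # params -> torch.optim.Optimizer
    )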