# Source code for otx.core.ov.models.mmov_model
"""MMOVModel for otx.core.ov.models.mmov_model."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Union
import openvino.runtime as ov
import torch
# TODO: Need to remove line 1 (ignore mypy) and fix mypy issues
from .ov_model import OVModel # type: ignore[attr-defined]
from .parser_mixin import ParserMixin # type: ignore[attr-defined]
# TODO: Need to fix pylint issues
# pylint: disable=keyword-arg-before-vararg
class MMOVModel(OVModel, ParserMixin):
    """MMOVModel for OMZ model type.

    Combines ``OVModel`` with ``ParserMixin``: the constructor runs the
    inherited ``parse`` step to resolve ``inputs``/``outputs`` before passing
    them to the parent initializer, and ``forward`` maps positional inputs onto
    the keyword arguments of the parent ``forward``.
    """

    def __init__(
        self,
        model_path_or_model: Union[str, ov.Model],
        weight_path: Optional[str] = None,
        inputs: Optional[Union[Dict[str, Union[str, List[str]]], List[str], str]] = None,
        outputs: Optional[Union[Dict[str, Union[str, List[str]]], List[str], str]] = None,
        *args,
        **kwargs
    ):
        """Parse the model's inputs/outputs, then initialize the parent model.

        Args:
            model_path_or_model: Path to a model, or an ``ov.Model`` instance.
            weight_path: Optional path forwarded to both ``parse`` and the
                parent initializer.
            inputs: Input specification handed to ``parse``; the value returned
                by ``parse`` is what the parent initializer receives.
            outputs: Output specification, handled the same way as ``inputs``.
            *args: Extra positional arguments forwarded to the parent
                initializer.
            **kwargs: Extra keyword arguments forwarded to the parent
                initializer. Two keys are consumed here and not forwarded:
                ``parser`` (passed to ``parse`` as ``parser``) and
                ``parser_kwargs`` (expanded into ``parse`` as keyword
                arguments).
        """
        parser = kwargs.pop("parser", None)
        # An explicit ``parser_kwargs=None`` (e.g. a null in a config) is treated
        # like "no extra parser arguments" instead of failing on ``**None``.
        parser_kwargs = kwargs.pop("parser_kwargs", None) or {}

        inputs, outputs = super().parse(
            model_path_or_model=model_path_or_model,
            weight_path=weight_path,
            inputs=inputs,
            outputs=outputs,
            parser=parser,
            **parser_kwargs,
        )

        # NOTE(review): ``*args`` are passed positionally, so they bind to the
        # parent's leading positional parameters. If those are the same
        # parameters given by keyword below, a non-empty ``args`` raises
        # TypeError - confirm against ``OVModel.__init__``.
        super().__init__(
            *args,
            model_path_or_model=model_path_or_model,
            weight_path=weight_path,
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

    def forward(self, inputs, gt_label=None):
        """Run the parent ``forward`` with positional inputs mapped to input names.

        Args:
            inputs: A single ``torch.Tensor`` or a sequence of inputs, one per
                entry of ``self.inputs`` and in the same order.
            gt_label: Optional value stored in ``self.features["gt_label"]``
                before the parent ``forward`` is called.

        Returns:
            The values of the parent's output mapping as a tuple, or the single
            value itself when the mapping holds exactly one entry.

        Raises:
            AssertionError: If the number of inputs does not match
                ``len(self.inputs)``, or if ``gt_label`` is given while
                ``"gt_label"`` is already present in ``self.features``.
        """
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)

        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; otherwise ``zip`` below would silently drop the
        # unmatched inputs. AssertionError is kept for backward compatibility.
        if len(inputs) != len(self.inputs):
            raise AssertionError(f"Expected {len(self.inputs)} input(s), but got {len(inputs)}.")
        feed_dict = dict(zip(self.inputs, inputs))

        if gt_label is not None:
            # Refuse to overwrite an existing "gt_label" feature.
            if "gt_label" in self.features:
                raise AssertionError('"gt_label" is already present in self.features.')
            self.features["gt_label"] = gt_label

        outputs = tuple(super().forward(**feed_dict).values())
        # Unwrap single-output models so callers get the value, not a 1-tuple.
        if len(outputs) == 1:
            return outputs[0]
        return outputs