Source code for otx.core.ov.utils
# type: ignore
# TODO: Remove the `# type: ignore` directive above and fix the resulting mypy issues
"""Utils for otx.core.ov."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
import errno
import os
from typing import Optional
from openvino.runtime import Core, Model, Node
from .omz_wrapper import AVAILABLE_OMZ_MODELS, get_omz_model
# pylint: disable=too-many-locals
[docs]
def to_dynamic_model(ov_model: Model) -> Model:
    """Convert ov_model to dynamic Model.

    For every model input that declares a layout, the axes labelled ``N``
    (batch), ``H`` (height) and ``W`` (width) are requested as dynamic (``-1``)
    and the model is reshaped. Inputs whose layout is empty are left out of
    the reshape request.

    If the reshape is rejected, the request is relaxed step by step and
    retried: first the original height/width extents are restored on *all*
    collected inputs, then the original batch extents. Each relaxation step is
    applied to every input before the next attempt.

    Args:
        ov_model: The model to reshape. ``reshape`` is called on this object.

    Returns:
        The same ``ov_model`` object that was passed in. Reshape errors are
        swallowed, so the model is returned even if the last attempt failed.

    Raises:
        AssertionError: If ``ov_model`` is not a ``Model`` instance.
    """
    assert isinstance(ov_model, Model)

    shapes = {}
    target_layouts = {}
    for input_node in ov_model.inputs:
        # Per-input bookkeeping: [layout letter, axis index, original extent].
        # Index and extent stay None when the letter is absent from the layout.
        target_layout = {
            "batch": ["N", None, None],
            "height": ["H", None, None],
            "width": ["W", None, None],
        }

        any_name = input_node.any_name
        layout = input_node.get_node().get_layout()
        if layout.empty:
            continue

        # Drop the first and last character of the layout string, then split
        # the remainder into one name per axis.
        layout_names = layout.to_string()[1:-1].split(",")
        shape = [str(dim) for dim in input_node.get_partial_shape()]
        for axis, (layout_name, extent) in enumerate(zip(layout_names, shape)):
            try:
                extent = int(extent)
            except ValueError:
                # Anything that does not parse as an integer is requested as -1.
                extent = -1

            for record in target_layout.values():
                if layout_name == record[0]:
                    # Remember where the axis is and what it was, then ask for
                    # it to become dynamic.
                    record[1] = axis
                    record[2] = extent
                    extent = -1
                    break
            shape[axis] = extent

        shapes[any_name] = shape
        target_layouts[any_name] = target_layout

    def reshape_model(ov_model, shapes):
        """Try to reshape; report success instead of raising."""
        try:
            ov_model.reshape(shapes)
        except Exception:  # pylint: disable=broad-exception-caught
            return False
        return True

    # Relaxation order: give up dynamic height/width first, dynamic batch last.
    # Each stage is applied to every input exactly once before retrying, so all
    # inputs are relaxed consistently regardless of how many inputs there are.
    for targets in (("height", "width"), ("batch",)):
        if reshape_model(ov_model, shapes):
            return ov_model
        for key, shape in shapes.items():
            target_layout = target_layouts[key]
            for target in targets:
                target_idx, target_origin = target_layout[target][1:]
                if target_idx is not None:
                    shape[target_idx] = target_origin

    # Final best-effort attempt with every tracked axis restored; the outcome
    # is intentionally not checked.
    reshape_model(ov_model, shapes)
    return ov_model
[docs]
def load_ov_model(model_path: str, weight_path: Optional[str] = None, convert_dynamic: bool = False) -> Model:
    """Load ov_model from model_path.

    Args:
        model_path: Path to the model file, or ``omz://<name>`` where
            ``<name>`` must be listed in ``AVAILABLE_OMZ_MODELS``; in that case
            both model and weight paths are taken from ``get_omz_model``.
        weight_path: Path to the weights file. When ``None``, the model path
            with its extension replaced by ``.bin`` is used.
        convert_dynamic: When true, the loaded model is passed through
            ``to_dynamic_model`` before being returned.

    Returns:
        The model read by ``Core.read_model``.

    Raises:
        AssertionError: If an ``omz://`` name is not in ``AVAILABLE_OMZ_MODELS``.
        FileNotFoundError: If the model file or the weights file does not exist.
    """
    ir_file = str(model_path)
    weights_file = weight_path

    if ir_file.startswith("omz://"):
        omz_name = ir_file.replace("omz://", "")
        assert omz_name in AVAILABLE_OMZ_MODELS
        omz_files = get_omz_model(omz_name)
        ir_file, weights_file = omz_files["model_path"], omz_files["weight_path"]

    if weights_file is None:
        weights_file = os.path.splitext(ir_file)[0] + ".bin"

    # Model file is checked first, then the weights file.
    for required in (ir_file, weights_file):
        if not os.path.exists(required):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), required)

    loaded = Core().read_model(model=ir_file, weights=weights_file)
    return to_dynamic_model(loaded) if convert_dynamic else loaded
[docs]
def normalize_name(name: str) -> str:
    """Normalize name string.

    Every ``.`` in ``name`` is replaced with ``#``, because ModuleDict does
    not allow '.' in module name string.
    """
    return name.replace(".", "#")
[docs]
def unnormalize_name(name: str) -> str:
    """Unnormalize name string.

    Every ``#`` in ``name`` is replaced with ``.``.
    """
    return name.replace("#", ".")
[docs]
def get_op_name(op_node: Node) -> str:
    """Get op name string.

    Returns the node's friendly name after passing it through
    ``normalize_name``.
    """
    return normalize_name(op_node.get_friendly_name())