Source code for otx.cli.registry.registry
"""Model templates registry."""
# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import copy
import glob
import os
from pathlib import Path
from typing import Optional
import yaml
from otx.api.entities.model_template import parse_model_template
from otx.cli.utils.importing import get_backbone_list, get_otx_root_path
[docs]
class Registry:
"""Class that implements a model templates registry."""
def __init__(self, templates_dir=None, templates=None, experimental=False):
if templates is None:
if templates_dir is None:
templates_dir = os.getenv("TEMPLATES_DIR")
if templates_dir is None:
raise RuntimeError("The templates_dir is not set.")
template_filenames = glob.glob(os.path.join(templates_dir, "**", "template.yaml"), recursive=True)
if experimental:
template_filenames.extend(
glob.glob(
os.path.join(templates_dir, "**", "template_experimental.yaml"),
recursive=True,
)
)
template_filenames = [os.path.abspath(p) for p in template_filenames]
self.templates = []
for template_file in template_filenames:
self.templates.append(parse_model_template(template_file))
else:
self.templates = copy.deepcopy(templates)
self.task_types = self.__collect_task_types(self.templates)
@staticmethod
def __collect_task_types(templates):
return {template.task_type for template in templates}
[docs]
def filter(self, framework=None, task_type=None):
"""Filters registry by framework and/or task type and returns filtered registry."""
templates = copy.deepcopy(self.templates)
if framework is not None:
templates = [template for template in templates if template.framework.lower() == framework.lower()]
if task_type is not None:
templates = [template for template in templates if str(template.task_type).lower() == task_type.lower()]
return Registry(templates=templates)
[docs]
def get(self, template_id, skip_error=False):
"""Returns a model template with specified template_id or template.name."""
templates = [
template
for template in self.templates
if str(template_id).upper() in (str(template.model_template_id).upper(), str(template.name).upper())
]
if not templates:
if skip_error:
return None
raise ValueError(f"Could not find a template with {template_id} in registry.")
return templates[0]
[docs]
def get_backbones(self, backend_list):
"""Returns list of backbones for a given template."""
backbone_list = {}
for backend in backend_list:
backbone_list[backend] = get_backbone_list(backend)
return backbone_list
def __str__(self):
"""Returns the string representation of the registry."""
templates_infos = [
{
"name": t.name,
"id": t.model_template_id,
"path": t.model_template_path,
"task_type": str(t.task_type),
}
for t in self.templates
]
return yaml.dump(templates_infos)
[docs]
def find_and_parse_model_template(path_or_id):
"""In first function attempts to read a model template from disk under assumption that a path is passed.
If the attempt is failed, it tries to find template in registry under assumption that an ID or name is passed.
"""
# Return None Type
if not path_or_id:
return path_or_id
# 1. Find from path
if is_template(path_or_id):
return parse_model_template(path_or_id)
# 2. Find from id or Name
return Registry(get_otx_root_path()).get(path_or_id, skip_error=True)
def is_template(template_path: Optional[str]) -> bool:
"""A function that determines whether the corresponding template path is a template.
Args:
template_path (str): The path of the file you want to know if it is a template.
Returns:
bool: True if template_path is template file else False.
"""
if template_path and Path(template_path).is_file() and "template" in Path(template_path).name:
return True
return False