"""CLI entrypoints."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import sys
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from warnings import warn
import yaml
from jsonargparse import ActionConfigFile, ArgumentParser, Namespace, namespace_to_dict
from rich.console import Console
from otx import OTX_LOGO, __version__
from otx.cli.utils import absolute_path
from otx.cli.utils.help_formatter import CustomHelpFormatter
from otx.cli.utils.jsonargparse import get_short_docstring, patch_update_configs
from otx.cli.utils.workspace import Workspace
from otx.core.types.task import OTXTaskType
from otx.core.utils.imports import get_otx_root_path
if TYPE_CHECKING:
from jsonargparse._actions import _ActionSubCommands
from otx.core.data.module import OTXDataModule
from otx.core.model.base import OTXModel
_ENGINE_AVAILABLE = True
try:
from otx.core.config import register_configs
from otx.engine import Engine
register_configs()
except ImportError:
_ENGINE_AVAILABLE = False
[docs]
class OTXCLI:
"""OTX CLI entrypoint."""
datamodule: OTXDataModule
def __init__(self, args: list[str] | None = None, run: bool = True) -> None:
"""Initialize OTX CLI."""
self.console = Console()
self._subcommand_method_arguments: dict[str, list[str]] = {}
with patch_update_configs():
self.parser = self.init_parser()
self.add_subcommands()
self.config = self.parser.parse_args(args=args, _skip_check=True)
self.subcommand = self.config["subcommand"]
if run:
self.run()
[docs]
def init_parser(self) -> ArgumentParser:
"""Initialize the argument parser for the OTX CLI.
Returns:
ArgumentParser: The initialized argument parser.
"""
parser = ArgumentParser(
description="OpenVINO Training-Extension command line tool",
env_prefix="otx",
parser_mode="omegaconf",
formatter_class=CustomHelpFormatter,
)
parser.add_argument(
"-v",
"--version",
action="version",
version=f"%(prog)s {__version__}",
help="Display OTX version number.",
)
return parser
[docs]
@staticmethod
def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser, list]:
"""Creates an ArgumentParser object for the engine subcommand.
Args:
**kwargs: Additional keyword arguments to be passed to the ArgumentParser constructor.
Returns:
ArgumentParser: The created ArgumentParser object.
"""
parser = ArgumentParser(
formatter_class=CustomHelpFormatter,
parser_mode="omegaconf",
**kwargs,
)
parser.add_argument(
"-v",
"--verbose",
action="count",
help="Verbose mode. This shows a configuration argument that allows for more specific overrides. \
Multiple -v options increase the verbosity. The maximum is 2.",
)
parser.add_argument(
"-c",
"--config",
action=ActionConfigFile,
help="Path to a configuration file in json or yaml format.",
)
parser.add_argument(
"--data_root",
type=absolute_path,
help="Path to dataset root.",
)
parser.add_argument(
"--work_dir",
type=absolute_path,
default=absolute_path(Path.cwd()),
help="Path to work directory. The default is created as otx-workspace.",
)
parser.add_argument(
"--task",
type=str,
help="Task Type.",
)
parser.add_argument(
"--seed",
type=int,
help="Sets seed for pseudo-random number generators in: pytorch, numpy, python.random.",
)
parser.add_argument(
"--callback_monitor",
type=str,
help="The metric to monitor the model performance during training callbacks.",
)
parser.add_argument(
"--disable-infer-num-classes",
help="OTX automatically infers num_classes from the given dataset "
"and applies it to the model initialization."
"Consequently, there might be a mismatch with the provided model configuration during runtime. "
"Setting this option to true will disable this behavior.",
action="store_true",
)
engine_skip = {"model", "datamodule", "work_dir"}
parser.add_class_arguments(
Engine,
"engine",
fail_untyped=False,
sub_configs=True,
instantiate=False,
skip=engine_skip,
)
# Model Settings
from otx.core.model.base import OTXModel
parser.add_subclass_arguments(
OTXModel,
"model",
required=False,
fail_untyped=False,
)
# Datamodule Settings
from otx.core.data.module import OTXDataModule
parser.add_class_arguments(
OTXDataModule,
"data",
fail_untyped=False,
sub_configs=True,
)
parser.add_class_arguments(Workspace, "workspace")
parser.link_arguments("work_dir", "workspace.work_dir")
parser.link_arguments("data_root", "engine.data_root")
parser.link_arguments("data_root", "data.data_root")
parser.link_arguments("engine.device", "data.device")
added_arguments = parser.add_method_arguments(
Engine,
subcommand,
skip=set(OTXCLI.engine_subcommands()[subcommand]),
fail_untyped=False,
)
if "callbacks" in added_arguments:
parser.link_arguments("callback_monitor", "callbacks.init_args.monitor")
parser.link_arguments("workspace.work_dir", "callbacks.init_args.dirpath", apply_on="instantiate")
if "logger" in added_arguments:
parser.link_arguments("workspace.work_dir", "logger.init_args.save_dir", apply_on="instantiate")
parser.link_arguments("workspace.work_dir", "logger.init_args.log_dir", apply_on="instantiate")
if "checkpoint" in added_arguments and "--checkpoint" in sys.argv:
# This is code for an OVModel that uses checkpoint in model.model_name.
parser.link_arguments("checkpoint", "model.init_args.model_name")
# Load default subcommand config file
default_config_file = get_otx_root_path() / "recipe" / "_base_" / f"{subcommand}.yaml"
if default_config_file.exists():
with Path(default_config_file).open() as f:
default_config = yaml.safe_load(f)
parser.set_defaults(**default_config)
return parser, added_arguments
[docs]
@staticmethod
def engine_subcommands() -> dict[str, set[str]]:
"""Returns dictionary the subcommands of engine, and whose value is the argument to be skipped in the CLI.
This allows the CLI to skip duplicate keys when creating the Engine and when running the subcommand.
Returns:
A dictionary where the keys are the subcommands and the values are sets of skipped arguments.
"""
device_kwargs = {"accelerator", "devices"}
return {
"train": {"seed"}.union(device_kwargs),
"test": {"datamodule"}.union(device_kwargs),
"predict": {"datamodule"}.union(device_kwargs),
"export": device_kwargs,
"optimize": {"datamodule"}.union(device_kwargs),
"explain": {"datamodule"}.union(device_kwargs),
"benchmark": device_kwargs,
}
[docs]
def add_subcommands(self) -> None:
"""Adds subcommands to the CLI parser.
This method initializes and configures subcommands for the OTX CLI parser.
It iterates over the available subcommands, adds arguments specific to each subcommand,
and registers them with the parser.
Returns:
None
"""
self._subcommand_parsers: dict[str, ArgumentParser] = {}
parser_subcommands = self.parser.add_subcommands()
self._set_extension_subcommands_parser(parser_subcommands)
if not _ENGINE_AVAILABLE:
# If environment is not configured to use Engine, do not add a subcommand for Engine.
return
for subcommand in self.engine_subcommands():
# If already have a workspace or run it from the root of a workspace, utilize config and checkpoint in cache
root_dir = Path(sys.argv[sys.argv.index("--work_dir") + 1]) if "--work_dir" in sys.argv else Path.cwd()
self.cache_dir = root_dir / ".latest" / "train" # The config and checkpoint used in the latest training.
parser_kwargs = self._set_default_config()
sub_parser, added_arguments = self.engine_subcommand_parser(subcommand=subcommand, **parser_kwargs)
if "--config" not in sys.argv and "checkpoint" in added_arguments and self.cache_dir.exists():
# If the user specifies the config directly, not set the cache ckpt as default.
self._load_cache_ckpt(parser=sub_parser)
fn = getattr(Engine, subcommand)
description = get_short_docstring(fn)
self._subcommand_method_arguments[subcommand] = added_arguments
self._subcommand_parsers[subcommand] = sub_parser
parser_subcommands.add_subcommand(subcommand, sub_parser, help=description)
def _load_cache_ckpt(self, parser: ArgumentParser) -> None:
checkpoint_dir = self.cache_dir / "checkpoints"
if not checkpoint_dir.exists():
return
ckpt_files = list(checkpoint_dir.glob("epoch_*.ckpt"))
if not ckpt_files:
return
latest_checkpoint = max(ckpt_files, key=lambda p: p.stat().st_mtime)
parser.set_defaults(checkpoint=str(latest_checkpoint))
if "--print_config" not in sys.argv:
warn(f"Load default checkpoint from {latest_checkpoint}.", stacklevel=0)
def _set_default_config(self) -> dict:
parser_kwargs = {}
if "--config" not in sys.argv and (self.cache_dir / "configs.yaml").exists():
parser_kwargs["default_config_files"] = [str(self.cache_dir / "configs.yaml")]
if "--print_config" not in sys.argv:
warn(f"Load default config from {self.cache_dir / 'configs.yaml'}.", stacklevel=0)
return parser_kwargs
# If don't use cache, use the default config from auto configuration.
data_root = None
task = None
if "--data_root" in sys.argv:
data_root = sys.argv[sys.argv.index("--data_root") + 1]
if "--task" in sys.argv:
task = sys.argv[sys.argv.index("--task") + 1]
enable_auto_config = data_root is not None and "--config" not in sys.argv
if enable_auto_config:
from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator
auto_configurator = AutoConfigurator(
data_root=data_root,
task=OTXTaskType(task) if task is not None else task,
)
config_file_path = DEFAULT_CONFIG_PER_TASK[auto_configurator.task]
parser_kwargs["default_config_files"] = [str(config_file_path)]
return parser_kwargs
def _set_extension_subcommands_parser(self, parser_subcommands: _ActionSubCommands) -> None:
from otx.cli.install import add_install_parser
add_install_parser(parser_subcommands)
if _ENGINE_AVAILABLE:
# `otx find` arguments
find_parser = ArgumentParser(formatter_class=CustomHelpFormatter)
find_parser.add_argument(
"--task",
help="Value for filtering by task. Default is None, which shows all recipes.",
type=Optional[OTXTaskType],
)
find_parser.add_argument(
"--pattern",
help="This allows you to filter the model name of the recipe. \
For example, if you want to find all models that contain the word 'efficient', \
you can use '--pattern efficient'",
type=Optional[str],
)
parser_subcommands.add_subcommand("find", find_parser, help="This shows the model provided by OTX.")
[docs]
def instantiate_classes(self, instantiate_engine: bool = True) -> None:
"""Instantiate the necessary classes based on the subcommand.
This method checks if the subcommand is one of the engine subcommands.
If it is, it instantiates the necessary classes such as config, datamodule, model, and engine.
Args:
instantiate_engine (bool, optional): Whether to instantiate the engine. Defaults to True.
"""
if self.subcommand in self.engine_subcommands():
# For num_classes update, Model and Metric are instantiated separately.
model_config = self.config[self.subcommand].pop("model")
# if adaptive_input_size will be executed and the model has input_size_multiplier, pass it to OTXDataModule
if self.config[self.subcommand].data.get("adaptive_input_size") is not None:
from otx.utils.utils import get_model_cls_from_config
model_cls = get_model_cls_from_config(model_config)
self.config[self.subcommand].data.input_size_multiplier = model_cls.input_size_multiplier
# Instantiate the things that don't need to special handling
self.config_init = self.parser.instantiate_classes(self.config)
self.workspace = self.get_config_value(self.config_init, "workspace")
self.datamodule = self.get_config_value(self.config_init, "data")
# pass OTXDataModule input size to the model
if (input_size := self.datamodule.input_size) is not None and "input_size" in model_config["init_args"]:
model_config["init_args"]["input_size"] = (
(input_size, input_size) if isinstance(input_size, int) else tuple(input_size)
)
# Instantiate the model and needed components
self.model = self.instantiate_model(model_config=model_config)
if instantiate_engine:
self.engine = self.instantiate_engine()
[docs]
def instantiate_engine(self) -> Engine:
"""Instantiate an Engine object with the specified parameters.
Returns:
An instance of the Engine class.
"""
engine_kwargs = self.get_config_value(self.config_init, "engine")
return Engine(
model=self.model,
datamodule=self.datamodule,
work_dir=self.workspace.work_dir,
**engine_kwargs,
)
[docs]
def instantiate_model(self, model_config: Namespace) -> OTXModel:
"""Instantiate the model based on the subcommand.
This method checks if the subcommand is one of the engine subcommands.
If it is, it instantiates the model.
Args:
model_config (Namespace): The model configuration.
Returns:
tuple: The model and optimizer and scheduler.
"""
from otx.core.model.base import OTXModel
from otx.utils.utils import can_pass_tile_config, get_model_cls_from_config, should_pass_label_info
skip = set()
# Update label_info
model_cls = get_model_cls_from_config(model_config)
if should_pass_label_info(model_cls) and not self.get_config_value(
self.config_init,
"disable_infer_num_classes",
False,
):
model_config.init_args.label_info = self.datamodule.label_info
warning_msg = (
"Automatically infer label_info from the given dataset. "
"Then, giving it to the OTXModel.__init__() argument. "
"If you don't want this behavior, please use `--disable-infer-num-classes` option."
)
warn(warning_msg, stacklevel=0)
skip.add("label_info")
# Update tile config due to adaptive tiling
if can_pass_tile_config(model_cls):
model_config.init_args.tile_config = self.datamodule.tile_config
skip.add("tile_config")
# NOTE: Workaround for jsonargparse cannot parse lambda default with unknown reasons
optimizer_arg, scheduler_arg = model_config.init_args.get("optimizer"), model_config.init_args.get("scheduler")
if isinstance(optimizer_arg, str) and optimizer_arg.endswith("<lambda>"):
model_config.init_args.pop("optimizer")
if isinstance(scheduler_arg, str) and scheduler_arg.endswith("<lambda>"):
model_config.init_args.pop("scheduler")
# Parses the OTXModel separately to update num_classes.
model_parser = ArgumentParser()
model_parser.add_subclass_arguments(OTXModel, "model", skip=skip, required=False, fail_untyped=False)
model: OTXModel = model_parser.instantiate_classes(Namespace(model=model_config)).get("model")
self.config_init[self.subcommand]["model"] = model
# Update self.config with model
self.config[self.subcommand].update(Namespace(model=model_config))
return model
[docs]
def get_config_value(self, config: Namespace, key: str, default: Any = None) -> Any: # noqa: ANN401
"""Retrieves the value of a configuration key from the given config object.
Args:
config (Namespace): The config object containing the configuration values.
key (str): The key of the configuration value to retrieve.
default (Any, optional): The default value to return if the key is not found. Defaults to None.
Returns:
Any: The value of the configuration key, or the default value if the key is not found.
if the value is a Namespace, it is converted to a dictionary.
"""
result = config.get(str(self.subcommand), config).get(key, default)
return namespace_to_dict(result) if isinstance(result, Namespace) else result
[docs]
def get_subcommand_parser(self, subcommand: str | None) -> ArgumentParser:
"""Returns the argument parser for the specified subcommand.
Args:
subcommand (str | None): The name of the subcommand. If None, returns the main parser.
Returns:
ArgumentParser: The argument parser for the specified subcommand.
"""
if subcommand is None:
return self.parser
# return the subcommand parser for the subcommand passed
return self._subcommand_parsers[subcommand]
[docs]
def prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]:
"""Prepares the keyword arguments to pass to the subcommand to run."""
return {
k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand]
}
[docs]
def save_config(self, work_dir: Path) -> None:
"""Save the configuration for the specified subcommand.
Args:
work_dir (Path): The working directory where the configuration file will be saved.
The configuration is saved as a YAML file in the engine's working directory.
"""
self.config[self.subcommand].pop("workspace", None)
self.config[self.subcommand]["work_dir"] = str(self.workspace.work_dir.parent)
# TODO(vinnamki): Revisit it after changing the optimizer and scheduler instantiating.
cfg = deepcopy(self.config.get(str(self.subcommand), self.config))
cfg.model.init_args.pop("optimizer")
cfg.model.init_args.pop("scheduler")
cfg.model.init_args.pop("label_info")
cfg.model.init_args.pop("tile_config")
self.get_subcommand_parser(self.subcommand).save(
cfg=cfg,
path=work_dir / "configs.yaml",
overwrite=True,
multifile=False,
skip_check=True,
)
# if train -> Update `.latest` folder
self.update_latest(work_dir=work_dir)
[docs]
def update_latest(self, work_dir: Path) -> None:
"""Update the latest cache directory with the latest configurations and checkpoint file.
Args:
work_dir (Path): The working directory where the configurations and checkpoint files are located.
"""
latest_dir = work_dir.parent / ".latest"
latest_dir.mkdir(exist_ok=True)
cache_dir = latest_dir / self.subcommand
if cache_dir.exists():
cache_dir.unlink()
cache_dir.symlink_to(Path("..") / work_dir.relative_to(work_dir.parent))
[docs]
def set_seed(self) -> None:
"""Set the random seed for reproducibility.
This method retrieves the seed value from the argparser and uses it to set the random seed.
If a seed value is provided, it will be used to set the random seed using the
`seed_everything` function from the `lightning` module.
"""
seed = self.get_config_value(self.config, "seed", None)
if seed is not None:
from lightning import seed_everything
seed_everything(seed, workers=True)
[docs]
def run(self) -> None:
"""Executes the specified subcommand.
Raises:
ValueError: If the subcommand is not recognized.
"""
self.console.print(f"[blue]{OTX_LOGO}[/blue] ver.{__version__}", justify="center")
if self.subcommand == "install":
from otx.cli.install import otx_install
otx_install(**self.config["install"])
elif self.subcommand == "find":
from otx.engine.utils.api import list_models
list_models(print_table=True, **self.config[self.subcommand])
elif self.subcommand in self.engine_subcommands():
self.set_seed()
self.instantiate_classes()
fn_kwargs = self.prepare_subcommand_kwargs(self.subcommand)
fn = getattr(self.engine, self.subcommand)
try:
outputs = fn(**fn_kwargs)
self._print_results(outputs=outputs)
except Exception:
self.console.print_exception(width=self.console.width)
raise
self.save_config(work_dir=Path(self.engine.work_dir))
else:
msg = f"Unrecognized subcommand: {self.subcommand}"
raise ValueError(msg)
def _print_results(self, outputs: Any) -> None: # noqa: ANN401
if outputs is None:
return
if self.subcommand == "train" and isinstance(outputs, dict):
# Print Metric like 'otx test'
from rich.table import Column, Table
from torch import Tensor
table_headers = ["Train metric", "Value"]
columns = [Column(h, justify="center", style="magenta", width=self.console.width) for h in table_headers]
columns[0].style = "cyan"
table = Table(*columns)
for metric, row in outputs.items():
if isinstance(row, Tensor):
row = row.item() if row.numel() == 1 else row.tolist() # noqa: PLW2901
table.add_row(*[metric, f"{row}"])
self.console.print(table)
elif self.subcommand in ("export", "optimize"):
# Print output model path
self.console.print(f"{self.subcommand} output: {outputs}")
self.console.print(f"Work Directory: {self.engine.work_dir}")