"""Model training tool."""
# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
# pylint: disable=too-many-locals
import datetime
import time
from contextlib import ExitStack
from pathlib import Path
from typing import Optional
# Update environment variables for CLI use
import otx.cli # noqa: F401
from otx.api.entities.model import ModelEntity
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.entities.train_parameters import TrainParameters
from otx.api.serialization.label_mapper import label_schema_to_bytes
from otx.api.usecases.adapters.model_adapter import ModelAdapter
from otx.cli.manager import ConfigManager
from otx.cli.manager.config_manager import TASK_TYPE_TO_SUB_DIR_NAME
from otx.cli.utils.experiment import ResourceTracker
from otx.cli.utils.hpo import run_hpo
from otx.cli.utils.importing import get_impl_class
from otx.cli.utils.io import read_binary, read_label_schema, save_model_data
from otx.cli.utils.multi_gpu import MultiGPUManager, is_multigpu_child_process
from otx.cli.utils.parser import (
MemSizeAction,
add_hyper_parameters_sub_parser,
get_override_param,
get_parser_and_hprams_data,
)
from otx.cli.utils.report import get_otx_report
from otx.core.data.adapter import get_dataset_adapter
from otx.utils.logger import config_logger
def get_args():
"""Parses command line arguments."""
parser, hyper_parameters, params = get_parser_and_hprams_data()
parser.add_argument(
"--train-data-roots",
help="Comma-separated paths to training data folders.",
)
parser.add_argument("--train-ann-files", help="Comma-separated paths to train annotation files.")
parser.add_argument(
"--val-data-roots",
help="Comma-separated paths to validation data folders.",
)
parser.add_argument("--val-ann-files", help="Comma-separated paths to train annotation files.")
parser.add_argument(
"--unlabeled-data-roots",
help="Comma-separated paths to unlabeled data folders.",
)
parser.add_argument(
"--unlabeled-file-list",
help="Comma-separated paths to unlabeled file lists.",
)
parser.add_argument(
"--train-type",
help=f"The currently supported options: {TASK_TYPE_TO_SUB_DIR_NAME.keys()}. "
"Will be difined automatically if no value passed.",
type=str,
default=None,
)
parser.add_argument(
"--load-weights",
help="Load model weights from a previously saved checkpoint.",
)
parser.add_argument(
"--resume-from",
help="Resume training from a previously saved checkpoint.",
)
parser.add_argument(
"-o",
"--output",
help="Location where outputs (model & logs) will be stored.",
)
parser.add_argument(
"--workspace",
help="Location where the intermediate output of the training will be stored.",
default=None,
)
parser.add_argument(
"--enable-hpo",
action="store_true",
help="Execute hyper parameters optimization (HPO) before training.",
)
parser.add_argument(
"--hpo-time-ratio",
default=4,
type=float,
help="Expected ratio of total time to run HPO to time taken for full fine-tuning.",
)
parser.add_argument(
"--gpus",
type=str,
help="Comma-separated indices of GPU. \
If there are more than one available GPU, then model is trained with multi GPUs.",
)
parser.add_argument(
"--rdzv-endpoint",
type=str,
default="localhost:0",
help="Rendezvous endpoint for multi-node training.",
)
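# The rendezvous endpoint takes the form "<host>:<port>"; in the default
# "localhost:0", port 0 conventionally lets the OS pick a free port.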
parser.add_argument(
"--base-rank",
type=int,
default=0,
help="Base rank of the current node workers.",
)
parser.add_argument(
"--world-size",
type=int,
default=0,
help="Total number of workers in a worker group.",
)
parser.add_argument(
"--mem-cache-size",
action=MemSizeAction,
dest="params.algo_backend.mem_cache_size",
type=str,
required=False,
help="Size of memory pool for caching decoded data to load data faster. "
"For example, you can use digits for bytes size (e.g. 1024) or a string with size units "
"(e.g. 7KiB = 7 * 2^10, 3MB = 3 * 10^6, and 2G = 2 * 2^30).",
)
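# A sketch of how the size strings above are expected to convert (the
# authoritative parsing lives in MemSizeAction from otx.cli.utils.parser):
#   "1024" -> 1024 bytes          (plain digits are byte counts)
#   "7KiB" -> 7 * 2**10 bytes     (binary unit)
#   "3MB"  -> 3 * 10**6 bytes     (decimal unit)
#   "2G"   -> 2 * 2**30 bytes     (bare suffix, per the help text above)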
parser.add_argument(
"--deterministic",
action="store_true",
help="Set deterministic to True, default=False.",
)
parser.add_argument(
"--seed",
type=int,
help="Set seed for training.",
)
parser.add_argument(
"--data",
type=str,
default=None,
help="The data.yaml path want to use in train task.",
)
parser.add_argument(
"--encryption-key",
type=str,
default=None,
help="Encryption key required to train the encrypted dataset. It is not required the non-encrypted dataset",
)
parser.add_argument(
"--track-resource-usage",
type=str,
default=None,
help="Track resources utilization and max memory usage and save values at the output path. "
"The possible options are 'cpu', 'gpu' or you can set to a comma-separated list of resource types. "
"And 'all' is also available for choosing all resource types.",
)
sub_parser = add_hyper_parameters_sub_parser(parser, hyper_parameters, return_sub_parser=True)
# TODO: Temporary solution for cases where there is no template input
override_param = get_override_param(params)
if not hyper_parameters and "params" in params:
params = params[params.index("params") :]
for param in params:
if param == "--help":
print("Without template configuration, hparams information is unknown.")
elif param.startswith("--"):
sub_parser.add_argument(
f"{param}",
dest=f"params.{param[2:]}",
)
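# e.g. a trailing "params --learning_parameters.batch_size 32" (hypothetical
# parameter name) is registered with the argparse destination
# "params.learning_parameters.batch_size" via the dest mapping above.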
return parser.parse_args(), override_param
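# Example invocation parsed by get_args() (illustrative; the template name and
# all paths are hypothetical):
#   otx train MobileNet-V2 --train-data-roots data/train --val-data-roots data/val \
#       --output outputs --gpus 0,1 --mem-cache-size 512MiB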
def main():
"""Main function that invoke train function with ExitStack."""
with ExitStack() as exit_stack:
return train(exit_stack)
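# The ExitStack gathers cleanup callbacks (multi-GPU finalization, resource
# tracker shutdown) so they run even if train() raises. A minimal sketch of the
# pattern used here:
#   with ExitStack() as stack:
#       stack.callback(cleanup)  # runs on normal exit and on exceptions
#       do_work()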
def train(exit_stack: Optional[ExitStack] = None): # pylint: disable=too-many-branches, too-many-statements
"""Function that is used for model training."""
start_time = time.time()
mode = "train"
args, override_param = get_args()
config_manager = ConfigManager(args, workspace_root=args.workspace, mode=mode)
config_logger(config_manager.output_path / "otx.log", "INFO")
# Auto-Configuration for model template
config_manager.configure_template()
# Creates a workspace if it doesn't exist.
if not config_manager.check_workspace():
config_manager.build_workspace(new_workspace_path=args.workspace)
# Update Hyper Parameter Configs
hyper_parameters = config_manager.get_hyparams_config(override_param=override_param)
# Auto-Configuration for Dataset configuration
config_manager.configure_data_config(update_data_yaml=config_manager.check_workspace())
dataset_config = config_manager.get_dataset_config(
subsets=["train", "val", "unlabeled"],
hyper_parameters=hyper_parameters,
)
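# The dataset adapter selected below resolves the task- and format-specific
# loading logic, producing both the OTX dataset and the label schema that the
# task environment is built from.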
dataset_adapter = get_dataset_adapter(**dataset_config)
dataset, label_schema = dataset_adapter.get_otx_dataset(), dataset_adapter.get_label_schema()
# Get classes for Task, ConfigurableParameters and Dataset.
template = config_manager.template
task_class = get_impl_class(template.entrypoints.base)
environment = TaskEnvironment(
model=None,
hyper_parameters=hyper_parameters,
label_schema=label_schema,
model_template=template,
)
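# --load-weights initializes the model from a checkpoint, while --resume-from
# additionally restores the training state; the "resume" flag in the adapter
# dict below tells the task which behavior to apply.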
if args.load_weights or args.resume_from:
ckpt_path = args.resume_from if args.resume_from else args.load_weights
model_adapters = {
"path": ckpt_path,
"weights.pth": ModelAdapter(read_binary(ckpt_path)),
"resume": bool(args.resume_from),
}
if (Path(ckpt_path).parent / "label_schema.json").exists():
model_adapters.update(
{"label_schema.json": ModelAdapter(label_schema_to_bytes(read_label_schema(ckpt_path)))}
)
environment.model = ModelEntity(
train_dataset=dataset,
configuration=environment.get_model_configuration(),
model_adapters=model_adapters, # type: ignore
)
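# When --enable-hpo is set, run_hpo performs hyper-parameter optimization ahead
# of the main training and returns an environment updated with the best
# parameters found; --hpo-time-ratio bounds the HPO budget relative to one full
# fine-tuning run (e.g. the default of 4 allows roughly four times as long).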
if args.enable_hpo:
environment = run_hpo(
args.hpo_time_ratio, config_manager.output_path, environment, dataset, config_manager.data_config
)
(config_manager.output_path / "logs").mkdir(exist_ok=True, parents=True)
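# For multi-GPU training the parent process spawns one worker per additional GPU
# and each worker re-enters train(); is_multigpu_child_process() is checked
# further down so that only the parent performs reporting, symlinking, and cleanup.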
if args.gpus:
multigpu_manager = MultiGPUManager(
train,
args.gpus,
args.rdzv_endpoint,
args.base_rank,
args.world_size,
datetime.datetime.fromtimestamp(start_time),
)
if (
multigpu_manager.is_available()
and not template.task_type.is_anomaly  # anomaly tasks use a different mechanism for multi-GPU training
):
multigpu_manager.setup_multi_gpu_train(
str(config_manager.output_path), hyper_parameters if args.enable_hpo else None
)
if exit_stack is not None:
exit_stack.callback(multigpu_manager.finalize)
else:
print(
"Warning: due to abstract of ExitStack context, "
"if main process raises an error, all processes can be stuck."
)
task = task_class(task_environment=environment, output_path=str(config_manager.output_path / "logs"))
output_model = ModelEntity(dataset, environment.get_model_configuration())
resource_tracker = None
if args.track_resource_usage and not is_multigpu_child_process():
resource_tracker = ResourceTracker(
config_manager.output_path / "resource_usage.yaml", args.track_resource_usage, args.gpus
)
resource_tracker.start()
if exit_stack is not None:
exit_stack.callback(resource_tracker.stop)
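# task.train() populates output_model in place: the trained weights and the
# performance score read below are attached to it rather than returned.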
task.train(
dataset, output_model, train_parameters=TrainParameters(), seed=args.seed, deterministic=args.deterministic
)
if resource_tracker is not None and exit_stack is None:
resource_tracker.stop()
model_path = config_manager.output_path / "models"
save_model_data(output_model, str(model_path))
end_time = time.time()
sec = end_time - start_time
total_time = str(datetime.timedelta(seconds=sec))
print("otx train time elapsed: ", total_time)
model_results = {
"time elapsed": total_time,
"score": output_model.performance,
"model_path": str(model_path.absolute()),
}
if args.gpus and exit_stack is None:
multigpu_manager.finalize()
elif is_multigpu_child_process():
return
get_otx_report(
model_template=config_manager.template,
task_config=task.config,
data_config=config_manager.data_config,
results=model_results,
output_path=config_manager.output_path / "cli_report.log",
)
print(f"otx train CLI report has been generated: {config_manager.output_path / 'cli_report.log'}")
# Create a symbolic link pointing to the latest trained model output folder
latest_path = config_manager.workspace_root / "outputs" / "latest_trained_model"
if latest_path.exists():
latest_path.unlink()
elif not latest_path.parent.exists():
latest_path.parent.mkdir(exist_ok=True, parents=True)
latest_path.symlink_to(config_manager.output_path.resolve())
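# The resulting link, assuming a timestamped output folder (hypothetical name):
#   outputs/latest_trained_model -> <workspace>/outputs/20240101_120000_train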
if not is_multigpu_child_process():
task.cleanup()
return dict(retcode=0, template=template.name)
if __name__ == "__main__":
main()