Source code for otx.cli.utils.config

"""Utils for working with Configurable parameters."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from pathlib import Path

import yaml


[docs] def override_parameters(overrides, parameters): """Overrides parameters values by overrides.""" allowed_keys = {"default_value", "value"} for k, val in overrides.items(): if isinstance(val, dict): if k in parameters.keys(): override_parameters(val, parameters[k]) else: raise ValueError(f'The "{k}" is not in original parameters.') elif k in allowed_keys: parameters[k] = val else: raise ValueError(f'The "{k}" is not in allowed_keys: {allowed_keys}')
[docs] def configure_dataset(args, data_yaml_path=None): """Configure dataset args.""" # Create instances of Task, ConfigurableParameters and Dataset. data_subset_format = {"ann-files": None, "data-roots": None} data_config = {"data": {subset: data_subset_format.copy() for subset in ("train", "val", "test")}} data_config["data"]["unlabeled"] = {"file-list": None, "data-roots": None} if data_yaml_path and Path(data_yaml_path).exists(): with open(Path(data_yaml_path), "r", encoding="UTF-8") as stream: data_config = yaml.safe_load(stream) # The command's args are overridden and use first if "train_ann_files" in args and args.train_ann_files: data_config["data"]["train"]["ann-files"] = str(Path(args.train_ann_files).absolute()) if "train_data_roots" in args and args.train_data_roots: data_config["data"]["train"]["data-roots"] = str(Path(args.train_data_roots).absolute()) if "val_ann_files" in args and args.val_ann_files: data_config["data"]["val"]["ann-files"] = str(Path(args.val_ann_files).absolute()) if "val_data_roots" in args and args.val_data_roots: data_config["data"]["val"]["data-roots"] = str(Path(args.val_data_roots).absolute()) if "unlabeled_file_list" in args and args.unlabeled_file_list: data_config["data"]["unlabeled"]["file-list"] = str(Path(args.unlabeled_file_list).absolute()) if "unlabeled_data_roots" in args and args.unlabeled_data_roots: data_config["data"]["unlabeled"]["data-roots"] = str(Path(args.unlabeled_data_roots).absolute()) if "test_ann_files" in args and args.test_ann_files: data_config["data"]["test"]["ann-files"] = str(Path(args.test_ann_files).absolute()) if "test_data_roots" in args and args.test_data_roots: data_config["data"]["test"]["data-roots"] = str(Path(args.test_data_roots).absolute()) return data_config