# Source code for otx.cli.tools.export

"""Model exporting tool."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from pathlib import Path

# Update environment variables for CLI use
import otx.cli  # noqa: F401
from otx.api.entities.model import ModelEntity, ModelOptimizationType, ModelPrecision
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.usecases.adapters.model_adapter import ModelAdapter
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.cli.manager import ConfigManager
from otx.cli.utils.importing import get_impl_class
from otx.cli.utils.io import read_binary, read_label_schema, save_model_data
from otx.cli.utils.nncf import is_checkpoint_nncf
from otx.cli.utils.parser import add_hyper_parameters_sub_parser, get_override_param, get_parser_and_hprams_data
from otx.utils.logger import config_logger


def get_args():
    """Parse the command-line arguments of the export tool.

    The base parser is obtained from ``get_parser_and_hprams_data``. This
    function then:

    - registers the export-specific options on that parser;
    - attaches the hyper-parameter sub-parser, restricted to ``INFERENCE`` mode;
    - parses the command line.

    Returns:
        tuple: ``(args, override_param)``.

        - ``args`` is the parsed namespace.
        - ``override_param`` is the value ``get_override_param`` derives from the
          parser data (presumably the hyper-parameter overrides given on the
          command line — confirm against ``otx.cli.utils.parser``).
    """
    parser, hyper_parameters, params = get_parser_and_hprams_data()

    # Export-specific options as (flags, keyword arguments), registered in this order.
    export_options = (
        (("--load-weights",), {"help": "Load model weights from previously saved checkpoint."}),
        (("-o", "--output"), {"help": "Location where exported model will be stored."}),
        (
            ("--workspace",),
            {"help": "Path to the workspace where the command will run.", "default": None},
        ),
        (
            ("--dump-features",),
            {
                "action": "store_true",
                "help": "Whether to return feature vector and saliency map for explanation purposes.",
            },
        ),
        (
            ("--half-precision",),
            {
                "action": "store_true",
                "help": "This flag indicated if model is exported in half precision (FP16).",
            },
        ),
        (
            ("--export-type",),
            {"help": "Type of the resulting model (OpenVINO or ONNX).", "default": "openvino"},
        ),
    )
    for flags, options in export_options:
        parser.add_argument(*flags, **options)

    add_hyper_parameters_sub_parser(parser, hyper_parameters, modes=("INFERENCE",))
    override_param = get_override_param(params)
    return parser.parse_args(), override_param
def main():
    """Export a trained model to OpenVINO IR or ONNX.

    Steps:

    1. Parse the CLI arguments and configure the model template.
    2. Build a ``TaskEnvironment`` and ``ModelEntity`` from the checkpoint.
    3. Instantiate the task class selected by the template.
    4. Ask the task to export the model.
    5. Save the exported model data to the output directory.

    Returns:
        dict: ``{"retcode": 0, "template": <template name>}`` on success.

    Raises:
        ValueError: If ``--export-type`` is not ``openvino`` or ``onnx``
            (case-insensitive), or if no checkpoint was given and none could be
            derived from the current workspace.
    """
    args, override_param = get_args()

    # Validate the requested export type up front, before configuring the
    # template, reading the checkpoint and building the task.
    export_type_name = args.export_type.lower()
    if export_type_name not in ("openvino", "onnx"):
        raise ValueError(f"Unsupported export type: {args.export_type!r} (expected 'openvino' or 'onnx')")
    export_type = ExportType.OPENVINO if export_type_name == "openvino" else ExportType.ONNX

    config_manager = ConfigManager(args, mode="export", workspace_root=args.workspace)
    config_logger(config_manager.output_path / "otx.log", "INFO")
    # Auto-Configuration for model template
    config_manager.configure_template()

    # Load template.yaml file.
    template = config_manager.template

    # Without an explicit checkpoint, fall back to the latest trained model
    # of the current workspace (if we are running inside one).
    if not args.load_weights and config_manager.check_workspace():
        latest_model_path = (
            config_manager.workspace_root / "outputs" / "latest_trained_model" / "models" / "weights.pth"
        )
        args.load_weights = str(latest_model_path)
    if not args.load_weights:
        # Otherwise None would be passed to the checkpoint readers below and
        # fail there with an unrelated-looking error.
        raise ValueError(
            "No checkpoint to export: pass --load-weights or run the command inside an OTX workspace."
        )

    # Get class for Task: NNCF checkpoints use the template's NNCF entrypoint.
    is_nncf = is_checkpoint_nncf(args.load_weights)
    task_class = get_impl_class(template.entrypoints.nncf if is_nncf else template.entrypoints.base)

    # Get hyper parameters schema.
    hyper_parameters = config_manager.get_hyparams_config(override_param)
    assert hyper_parameters

    environment = TaskEnvironment(
        model=None,
        hyper_parameters=hyper_parameters,
        label_schema=read_label_schema(args.load_weights),
        model_template=template,
    )

    model_adapters = {"weights.pth": ModelAdapter(read_binary(args.load_weights))}
    model = ModelEntity(
        configuration=environment.get_model_configuration(),
        model_adapters=model_adapters,
        train_dataset=None,
        optimization_type=ModelOptimizationType.NNCF if is_nncf else ModelOptimizationType.NONE,
    )
    environment.model = model

    logs_path = config_manager.output_path / "logs"
    logs_path.mkdir(exist_ok=True, parents=True)
    task = task_class(task_environment=environment, output_path=str(logs_path))

    # The task fills this empty model entity with the exported artifacts.
    exported_model = ModelEntity(None, environment.get_model_configuration())
    export_precision = ModelPrecision.FP16 if args.half_precision else ModelPrecision.FP32
    task.export(export_type, exported_model, export_precision, args.dump_features)

    if args.output:
        output_path = Path(args.output)
    else:
        # NOTE(review): ONNX exports also land in the "openvino" sub-directory
        # by default; kept unchanged in case other tools expect this path — confirm.
        output_path = config_manager.output_path / "openvino"
    output_path.mkdir(exist_ok=True, parents=True)
    save_model_data(exported_model, str(output_path))

    return dict(retcode=0, template=template.name)
# Run the exporter only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()