Source code for otx.algorithms.action.tools.sample_classification
"""Sample Code of otx training for action classification."""
# Copyright (C) 2021-2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
# ruff: noqa: E402
import argparse
import os
import sys
os.environ["FEATURE_FLAGS_OTX_ACTION_TASKS"] = "1"
from otx.algorithms.common.utils import get_task_class
from otx.api.configuration.helper import create
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import ModelEntity
from otx.api.entities.model_template import parse_model_template
from otx.api.entities.optimization_parameters import OptimizationParameters
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.usecases.tasks.interfaces.optimization_interface import OptimizationType
from otx.core.data.adapter import get_dataset_adapter
from otx.utils.logger import get_logger
logger = get_logger()
[docs]
def parse_args():
"""Parse function for getting model template & check export."""
parser = argparse.ArgumentParser(description="Sample showcasing the new API")
parser.add_argument("template_file_path", help="path to template file")
parser.add_argument("--export", action="store_true")
return parser.parse_args()
TRAIN_DATA_ROOTS = "tests/assets/cvat_dataset/action_classification/train"
VAL_DATA_ROOTS = "tests/assets/cvat_dataset/action_classification/train"
[docs]
def load_test_dataset(model_template):
"""Load Sample dataset for detection."""
algo_backend = model_template.hyper_parameters.parameter_overrides["algo_backend"]
train_type = algo_backend["train_type"]["default_value"]
dataset_adapter = get_dataset_adapter(
model_template.task_type,
train_type,
train_data_roots=TRAIN_DATA_ROOTS,
val_data_roots=VAL_DATA_ROOTS,
)
dataset = dataset_adapter.get_otx_dataset()
label_schema = dataset_adapter.get_label_schema()
return dataset, label_schema
# pylint: disable=too-many-locals, too-many-statements
[docs]
def main(args):
"""Main function of Detection Sample."""
logger.info("Fine tuning sample dataset")
logger.info("Sample dataset can be found at tests/assets/cvat_dataset/action_classification")
logger.info("Load model template")
model_template = parse_model_template(args.template_file_path)
logger.info("Get dataset")
dataset, labels_schema = load_test_dataset(model_template)
logger.info("Set hyperparameters")
params = create(model_template.hyper_parameters.data)
params.learning_parameters.num_iters = 1
logger.info("Setup environment")
environment = TaskEnvironment(
model=None,
hyper_parameters=params,
label_schema=labels_schema,
model_template=model_template,
)
logger.info("Create base Task")
task_impl_path = model_template.entrypoints.base
task_cls = get_task_class(task_impl_path)
task = task_cls(task_environment=environment)
logger.info("Train model")
output_model = ModelEntity(
dataset,
environment.get_model_configuration(),
)
task.train(dataset, output_model)
logger.info("Get predictions on the validation set")
validation_dataset = dataset.get_subset(Subset.VALIDATION)
predicted_validation_dataset = task.infer(
validation_dataset.with_empty_annotations(),
InferenceParameters(is_evaluation=False),
)
resultset = ResultSetEntity(
model=output_model,
ground_truth_dataset=validation_dataset,
prediction_dataset=predicted_validation_dataset,
)
logger.info("Estimate quality on validation set")
task.evaluate(resultset)
logger.info(str(resultset.performance))
if args.export:
logger.info("Export model")
exported_model = ModelEntity(
dataset,
environment.get_model_configuration(),
)
task.export(ExportType.OPENVINO, exported_model, dump_features=False)
logger.info("Create OpenVINO Task")
environment.model = exported_model
openvino_task_impl_path = model_template.entrypoints.openvino
openvino_task_cls = get_task_class(openvino_task_impl_path)
openvino_task = openvino_task_cls(environment)
logger.info("Get predictions on the validation set")
predicted_validation_dataset = openvino_task.infer(
validation_dataset.with_empty_annotations(),
InferenceParameters(is_evaluation=True),
)
resultset = ResultSetEntity(
model=output_model,
ground_truth_dataset=validation_dataset,
prediction_dataset=predicted_validation_dataset,
)
logger.info("Estimate quality on validation set")
openvino_task.evaluate(resultset)
logger.info(str(resultset.performance))
logger.info("Run POT optimization")
optimized_model = ModelEntity(
dataset,
environment.get_model_configuration(),
)
openvino_task.optimize(
OptimizationType.POT,
dataset.get_subset(Subset.TRAINING),
optimized_model,
OptimizationParameters(),
)
logger.info("Get predictions on the validation set")
predicted_validation_dataset = openvino_task.infer(
validation_dataset.with_empty_annotations(),
InferenceParameters(is_evaluation=True),
)
resultset = ResultSetEntity(
model=optimized_model,
ground_truth_dataset=validation_dataset,
prediction_dataset=predicted_validation_dataset,
)
logger.info("Performance of optimized model:")
openvino_task.evaluate(resultset)
logger.info(str(resultset.performance))
if __name__ == "__main__":
sys.exit(main(parse_args()) or 0)