Source code for otx.algorithms.classification.tools.classification_sample

"""Sample code of otx training for classification."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=too-many-locals, invalid-name, too-many-statements

import argparse
import random
import sys

import numpy as np
import torch

from otx.algorithms.common.utils import get_task_class
from otx.api.configuration.helper import create
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import Domain
from otx.api.entities.label_schema import (
    LabelEntity,
    LabelGroup,
    LabelGroupType,
    LabelSchemaEntity,
)
from otx.api.entities.model import ModelEntity
from otx.api.entities.model_template import parse_model_template
from otx.api.entities.optimization_parameters import OptimizationParameters
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.usecases.tasks.interfaces.optimization_interface import OptimizationType
from otx.utils.logger import get_logger

SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

logger = get_logger()


[docs] def parse_args(): """Parse function for getting model template & check export.""" parser = argparse.ArgumentParser(description="Sample showcasing the new API") parser.add_argument("template_file_path", help="path to template file") parser.add_argument("--export", action="store_true") parser.add_argument("--multilabel", action="store_true") parser.add_argument("--hierarchical", action="store_true") return parser.parse_args()
[docs] def load_test_dataset(data_type, args): """Load test dataset.""" import PIL from PIL import ImageDraw from otx.api.entities.annotation import ( Annotation, AnnotationSceneEntity, AnnotationSceneKind, ) from otx.api.entities.dataset_item import DatasetItemEntity from otx.api.entities.image import Image from otx.api.entities.scored_label import ScoredLabel from otx.api.entities.shapes.rectangle import Rectangle def gen_image(resolution, shape=None): image = PIL.Image.new("RGB", resolution, (255, 255, 255)) draw = ImageDraw.Draw(image) h, w = image.size shape = shape.split("+") if "+" in shape else [shape] for s in shape: if s == "rectangle": draw.rectangle((h * 0.1, w * 0.1, h * 0.4, w * 0.4), fill=(0, 192, 192), outline=(0, 0, 0)) if s == "triangle": draw.polygon( ((h * 0.5, w * 0.25), (h, w * 0.25), (h * 0.8, w * 0.5)), fill=(255, 255, 0), outline=(0, 0, 0) ) if s == "pieslice": draw.pieslice( ((h * 0.1, w * 0.5), (h * 0.5, w * 0.9)), start=50, end=250, fill=(0, 255, 0), outline=(0, 0, 0) ) if s == "circle": draw.ellipse((h * 0.5, w * 0.5, h * 0.9, w * 0.9), fill="blue", outline="blue") if s == "text": draw.text((0, 0), "Intel", fill="blue", align="center") return np.array(image), shape datas = [ gen_image((32, 32), shape="rectangle"), gen_image((32, 32), shape="triangle"), gen_image((32, 32), shape="rectangle+triangle"), # for multilabel (old) gen_image((32, 32), shape="pieslice"), gen_image((32, 32), shape="pieslice+rectangle"), gen_image((32, 32), shape="pieslice+triangle"), gen_image((32, 32), shape="pieslice+rectangle+triangle"), # for multilabel (new) gen_image((32, 32), shape="circle"), gen_image((32, 32), shape="circle+text"), # for hierarchical (new) ] labels = { "rectangle": LabelEntity(name="rectangle", domain=Domain.CLASSIFICATION, id=0), "triangle": LabelEntity(name="triangle", domain=Domain.CLASSIFICATION, id=1), "pieslice": LabelEntity(name="pieslice", domain=Domain.CLASSIFICATION, id=2), "circle": LabelEntity(name="circle", domain=Domain.CLASSIFICATION, id=3), "text": LabelEntity(name="text", domain=Domain.CLASSIFICATION, id=4), } def get_image(i, subset, ignored_labels=None): image, shape = datas[i] lbl = [ScoredLabel(label=labels[s], probability=1.0) for s in shape] return DatasetItemEntity( media=Image(data=image), annotation_scene=AnnotationSceneEntity( annotations=[ Annotation( Rectangle(x1=0.0, y1=0.0, x2=1.0, y2=1.0), labels=lbl, ) ], kind=AnnotationSceneKind.ANNOTATION, ), subset=subset, ignored_labels=ignored_labels, ) def gen_old_new_dataset(multilabel=False, hierarchical=False): old_train, old_val, new_train, new_val = [], [], [], [] old_repeat = 8 new_repeat = 4 if multilabel: old_img_idx = [0, 1, 2] new_img_idx = [0, 1, 2, 3, 4, 5, 6] ignored_labels = [labels["pieslice"]] elif hierarchical: old_img_idx = [0, 1, 2, 3, 8] new_img_idx = [0, 1, 2, 3, 8, -1] ignored_labels = [labels["text"]] else: old_img_idx = [0, 1] new_img_idx = [0, 1, 3] for _ in range(old_repeat): for idx in old_img_idx: old_train.append(get_image(idx, Subset.TRAINING)) old_val.append(get_image(idx, Subset.VALIDATION)) for _ in range(new_repeat): for idx in new_img_idx: if multilabel: new_train.append(get_image(idx, Subset.TRAINING, ignored_labels=ignored_labels)) elif hierarchical: new_train.append(get_image(idx, Subset.TRAINING, ignored_labels=ignored_labels)) else: new_train.append(get_image(idx, Subset.TRAINING)) new_val.append(get_image(idx, Subset.VALIDATION)) return old_train + old_val, new_train + new_val old, new = gen_old_new_dataset(args.multilabel, args.hierarchical) if not args.hierarchical: labels = [labels["rectangle"], labels["triangle"], labels["pieslice"]] else: labels = list(labels.values()) if data_type == "old": return DatasetEntity(old), labels[:-1] return DatasetEntity(old + new), labels
[docs] def get_label_schema(labels, multilabel=False, hierarchical=False): """Get label schema.""" label_schema = LabelSchemaEntity() emptylabel = LabelEntity(id=-1, name="Empty label", is_empty=True, domain=Domain.CLASSIFICATION) empty_group = LabelGroup(name="empty", labels=[emptylabel], group_type=LabelGroupType.EMPTY_LABEL) if multilabel: for label in labels: label_schema.add_group(LabelGroup(name=label.name, labels=[label], group_type=LabelGroupType.EXCLUSIVE)) label_schema.add_group(empty_group) elif hierarchical: single_label_classes = ["pieslice", "circle"] multi_label_classes = ["rectangle", "triangle", "text"] single_labels = [label for label in labels if label.name in single_label_classes] single_label_group = LabelGroup(name="labels", labels=single_labels, group_type=LabelGroupType.EXCLUSIVE) label_schema.add_group(single_label_group) for label in labels: if label.name in multi_label_classes: label_schema.add_group( LabelGroup( name=f"{label.name}____{label.name}_group", labels=[label], group_type=LabelGroupType.EXCLUSIVE ) ) else: main_group = LabelGroup(name="labels", labels=labels, group_type=LabelGroupType.EXCLUSIVE) label_schema.add_group(main_group) return label_schema
[docs] def validate(task, validation_dataset, model): """Validate.""" print("Get predictions on the validation set") predicted_validation_dataset = task.infer( validation_dataset.with_empty_annotations(), InferenceParameters(is_evaluation=True) ) resultset = ResultSetEntity( model=model, ground_truth_dataset=validation_dataset, prediction_dataset=predicted_validation_dataset, ) print("Estimate quality on validation set") task.evaluate(resultset) print(str(resultset.performance))
[docs] def main(args): """Main of Classification Sample Test.""" logger.info("Train initial model with OLD dataset") dataset, labels_list = load_test_dataset("old", args) labels_schema = get_label_schema(labels_list, multilabel=args.multilabel, hierarchical=args.hierarchical) logger.info(f"Train dataset: {len(dataset.get_subset(Subset.TRAINING))} items") logger.info(f"Validation dataset: {len(dataset.get_subset(Subset.VALIDATION))} items") logger.info("Load model template") model_template = parse_model_template(args.template_file_path) logger.info("Set hyperparameters") params = create(model_template.hyper_parameters.data) params.learning_parameters.num_iters = 10 params.learning_parameters.learning_rate = 0.03 params.learning_parameters.learning_rate_warmup_iters = 4 params.learning_parameters.batch_size = 16 logger.info("Setup environment") environment = TaskEnvironment( model=None, hyper_parameters=params, label_schema=labels_schema, model_template=model_template ) logger.info("Create base Task") task_impl_path = model_template.entrypoints.base task_cls = get_task_class(task_impl_path) task = task_cls(task_environment=environment) logger.info("Train model") initial_model = ModelEntity( dataset, environment.get_model_configuration(), ) task.train(dataset, initial_model) logger.info("Class-incremental learning with OLD + NEW dataset") dataset, labels_list = load_test_dataset("new") labels_schema = get_label_schema(labels_list, multilabel=args.multilabel, hierarchical=args.hierarchical) logger.info(f"Train dataset: {len(dataset.get_subset(Subset.TRAINING))} items") logger.info(f"Validation dataset: {len(dataset.get_subset(Subset.VALIDATION))} items") logger.info("Load model template") model_template = parse_model_template(args.template_file_path) logger.info("Set hyperparameters") params = create(model_template.hyper_parameters.data) params.learning_parameters.num_iters = 10 params.learning_parameters.learning_rate = 0.03 params.learning_parameters.learning_rate_warmup_iters = 4 params.learning_parameters.batch_size = 16 logger.info("Setup environment") environment = TaskEnvironment( model=initial_model, hyper_parameters=params, label_schema=labels_schema, model_template=model_template ) logger.info("Create base Task") task_impl_path = model_template.entrypoints.base task_cls = get_task_class(task_impl_path) task = task_cls(task_environment=environment) logger.info("Train model") output_model = ModelEntity( dataset, environment.get_model_configuration(), ) task.train(dataset, output_model) logger.info("Get predictions on the validation set") validation_dataset = dataset.get_subset(Subset.VALIDATION) predicted_validation_dataset = task.infer( validation_dataset.with_empty_annotations(), InferenceParameters(is_evaluation=True) ) resultset = ResultSetEntity( model=output_model, ground_truth_dataset=validation_dataset, prediction_dataset=predicted_validation_dataset, ) logger.info("Estimate quality on validation set") task.evaluate(resultset) logger.info(str(resultset.performance)) if args.export or args.multilabel or args.hierarchical: logger.info("Export model") exported_model = ModelEntity( dataset, environment.get_model_configuration(), ) task.export(ExportType.OPENVINO, exported_model) logger.info("Create OpenVINO Task") environment.model = exported_model openvino_task_impl_path = model_template.entrypoints.openvino openvino_task_cls = get_task_class(openvino_task_impl_path) openvino_task = openvino_task_cls(environment) logger.info("Get predictions on the validation set") predicted_validation_dataset = openvino_task.infer( validation_dataset.with_empty_annotations(), InferenceParameters(is_evaluation=True) ) resultset = ResultSetEntity( model=exported_model, ground_truth_dataset=validation_dataset, prediction_dataset=predicted_validation_dataset, ) logger.info("Estimate quality on validation set") openvino_task.evaluate(resultset) logger.info(str(resultset.performance)) # POT test logger.info("Run POT optimization") optimized_model = ModelEntity( dataset, environment.get_model_configuration(), ) openvino_task.optimize(OptimizationType.POT, dataset, optimized_model, OptimizationParameters()) logger.info("Run POT deploy") openvino_task.deploy(optimized_model) validate(task, validation_dataset, optimized_model) # NNCF test task_impl_path = model_template.entrypoints.nncf nncf_task_cls = get_task_class(task_impl_path) print("Create NNCF Task") environment.model = output_model task = nncf_task_cls(task_environment=environment) print("Optimize model by NNCF") optimized_model = ModelEntity( dataset, environment.get_model_configuration(), ) task.optimize(OptimizationType.NNCF, dataset, optimized_model, None) validate(task, validation_dataset, output_model)
if __name__ == "__main__": sys.exit(main(parse_args()) or 0)