# Source code for otx.api.entities.task_environment

"""This module implements the TaskEnvironment entity."""

# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import List, Optional, Type, TypeVar

from otx.api.configuration import ConfigurableParameters, cfg_helper
from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelSchemaEntity
from otx.api.entities.model import ModelConfiguration, ModelEntity
from otx.api.entities.model_template import ModelTemplate

# Type variable bound to ConfigurableParameters. It lets ``get_hyper_parameters``
# declare that it returns an instance of the same class that was passed as
# ``instance_of``.
TypeVariable = TypeVar("TypeVariable", bound=ConfigurableParameters)


# pylint: disable=too-many-instance-attributes; Requires refactor
class TaskEnvironment:
    """Defines the machine learning environment the task runs in.

    Args:
        model_template (ModelTemplate): The model template used for this task
        model (Optional[ModelEntity]): Model to use; if not specified, the task must be either weight-less
            or use pre-trained or randomly initialised weights.
        hyper_parameters (ConfigurableParameters): Set of hyper parameters
        label_schema (LabelSchemaEntity): Label schema associated to this task
    """

    def __init__(
        self,
        model_template: ModelTemplate,
        model: Optional[ModelEntity],
        hyper_parameters: ConfigurableParameters,
        label_schema: LabelSchemaEntity,
    ) -> None:
        self.model_template = model_template
        self.model = model
        # Private (name-mangled) attribute: external code should go through
        # get_hyper_parameters / set_hyper_parameters; the setter type-checks its input.
        self.__hyper_parameters = hyper_parameters
        self.label_schema = label_schema

    def __repr__(self) -> str:
        """String representation of the TaskEnvironment object."""
        return (
            f"TaskEnvironment(model={self.model}, label_schema={self.label_schema}, "
            f"hyper_params={self.__hyper_parameters})"
        )

    def __eq__(self, other: object) -> bool:
        """Compares two TaskEnvironment objects.

        Two environments are equal when their model, label schema and hyper parameters
        are equal. The model template is not part of the comparison.

        Note: because ``__eq__`` is defined without ``__hash__``, Python makes instances
        of this class unhashable.

        Args:
            other (object): Object to compare with.

        Returns:
            bool: True if equal, False otherwise. For objects that are not TaskEnvironment
            instances ``NotImplemented`` is returned, so Python can try the reflected
            comparison (and ``==`` still evaluates to False for unrelated types).
        """
        if not isinstance(other, TaskEnvironment):
            return NotImplemented
        return (
            self.model == other.model
            and self.label_schema == other.label_schema
            # TODO get_hyperparameters should return Union rather than TypeVariable
            and self.get_hyper_parameters(instance_of=None)  # type: ignore
            == other.get_hyper_parameters(instance_of=None)
        )

    def get_labels(self, include_empty: bool = False) -> List[LabelEntity]:
        """Return the labels in this task environment (based on the label schema).

        Args:
            include_empty (bool): Include the empty label if ``True``. Defaults to False.

        Returns:
            List[LabelEntity]: List of labels
        """
        return self.label_schema.get_labels(include_empty)

    def get_hyper_parameters(self, instance_of: Optional[Type[TypeVariable]] = None) -> TypeVariable:
        """Returns Configuration for the task, de-serialized as type specified in `instance_of`.

        If the type of the configurable parameters is unknown, a generic
        ConfigurableParameters object with all available parameters can be obtained
        by calling method with instance_of = None. In that case the stored object
        itself is returned (not a copy), so mutating it changes this environment.

        Example:
            >>> self.get_hyper_parameters(instance_of=TorchSegmentationConfig)
            TorchSegmentationConfig()

        Args:
            instance_of (Optional[Type[TypeVariable]]): subclass of ConfigurableParameters
                to de-serialize the hyper parameters into. Defaults to None.

        Returns:
            TypeVariable: ConfigurableParameters entity
        """
        if instance_of is None:
            # If the instance_of is None, the type variable is not defined so the
            # return type won't be deduced correctly
            return self.__hyper_parameters  # type: ignore

        # Otherwise, build a fresh instance of the requested class (sharing the stored
        # header) and substitute the stored values into it; values missing from the
        # stored parameters are tolerated (allow_missing_values=True).
        base_config = instance_of(header=self.__hyper_parameters.header)
        cfg_helper.substitute_values(base_config, value_input=self.__hyper_parameters, allow_missing_values=True)
        return base_config

    def set_hyper_parameters(self, hyper_parameters: ConfigurableParameters) -> None:
        """Sets the hyper parameters for the task.

        Example:
            >>> self.set_hyper_parameters(hyper_parameters=TorchSegmentationParameters())

        Args:
            hyper_parameters (ConfigurableParameters): ConfigurableParameters entity to assign to task

        Raises:
            ValueError: If ``hyper_parameters`` is not a ConfigurableParameters instance.
        """
        if not isinstance(hyper_parameters, ConfigurableParameters):
            raise ValueError(f"Unable to set hyper parameters, invalid input: {hyper_parameters}")
        self.__hyper_parameters = hyper_parameters

    def get_model_configuration(self) -> ModelConfiguration:
        """Get the configuration needed to use the current model.

        That is the current set of:

        * configurable parameters
        * labels
        * label schema

        The configuration is built from the stored hyper parameters and the label schema.

        Returns:
            ModelConfiguration: Model configuration
        """
        return ModelConfiguration(self.__hyper_parameters, self.label_schema)