Source code for otx.api.usecases.tasks.interfaces.inference_interface

"""This module contains the interface class for tasks."""


# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import abc
from typing import Dict

import numpy as np

from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.inference_parameters import InferenceParameters


[docs] class IInferenceTask(metaclass=abc.ABCMeta): """A base interface class for a task."""
[docs] @abc.abstractmethod def infer( self, dataset: DatasetEntity, inference_parameters: InferenceParameters, ) -> DatasetEntity: """This is the method that is called upon inference. This happens when the user wants to analyse a sample or multiple samples need to be analysed. Args: dataset: The input dataset to perform the analysis on. inference_parameters: The parameters to use for the analysis. Returns: The results of the analysis. """ raise NotImplementedError
[docs] class IRawInference(metaclass=abc.ABCMeta): """A base interface class for raw inference tasks."""
[docs] @abc.abstractmethod def raw_infer( self, input_tensors: Dict[str, np.ndarray], output_tensors: Dict[str, np.ndarray], ): """This is the method that is called to run a neural network over a set of tensors. This method takes as input/output the tensors which are directly fed to the neural network, and does not include any additional pre- and post-processing of the inputs and outputs. Args: input_tensors: Dictionary containing the input tensors. output_tensors: Dictionary to be filled by the task with the output tensors. """ raise NotImplementedError