Source code for datumaro.components.abstracts.model_interpreter
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, TypeVar, Union
import numpy as np
from datumaro.components.annotation import Annotation
from datumaro.components.dataset_base import DatasetItem
# Names exported by `from ... import *` for this module.
__all__ = ["IModelInterpreter", "PrepInfo", "ModelPred", "LauncherInputType"]
# Generic placeholder for the preprocessing information that
# `IModelInterpreter.preprocess` returns and `IModelInterpreter.postprocess` receives.
PrepInfo = TypeVar("PrepInfo")
# A model prediction: either one dict mapping string keys to arrays,
# or a list of such dicts.
ModelPred = Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]
# Type of the preprocessed input returned by `IModelInterpreter.preprocess`:
# either a single array or a dict mapping string keys to arrays.
LauncherInputType = Union[np.ndarray, Dict[str, np.ndarray]]
[docs]
class IModelInterpreter(ABC):
    """Abstract interface between dataset items and a model's inputs/outputs.

    A concrete interpreter must override all three abstract methods:
    `preprocess` turns a dataset item into a model input, `postprocess` turns
    a model prediction back into annotations, and `get_categories` reports
    the categories produced by the model.
    """

    @abstractmethod
    def preprocess(self, inp: DatasetItem) -> Tuple[LauncherInputType, PrepInfo]:
        """Preprocess a dataset item input.

        Parameters:
            inp: DatasetItem input

        Returns:
            It returns a tuple of preprocessed input and preprocessing information.
            The preprocessing information will be used in the postprocessing step.
            One use case for this would be an affine transformation of the output
            bounding box obtained by object detection models. Input images for those
            models are usually resized to fit the model input dimensions. As a result,
            the postprocessing step should refine the model output bounding box to
            match the original input image size.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Function should be implemented.")

    @abstractmethod
    def postprocess(self, pred: ModelPred, info: PrepInfo) -> List[Annotation]:
        """Postprocess a model prediction.

        Parameters:
            pred: Model prediction
            info: Preprocessing information coming from the preprocessing step

        Returns:
            A list of annotations which is created from the model predictions

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Function should be implemented.")

    @abstractmethod
    def get_categories(self):
        """Return the categories of the model.

        It should be implemented if the model generates new categories.
        Because the method is abstract, every concrete subclass has to
        override it.
        NOTE(review): the expected return value for a model that adds no
        categories is not visible here -- confirm against the callers.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Function should be implemented.")