Source code for datumaro.plugins.inference_server_plugin.ovms
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import logging as log
from typing import List, Union
import numpy as np
from ovmsclient import make_grpc_client, make_http_client
from ovmsclient.tfs_compat.grpc.serving_client import GrpcClient
from ovmsclient.tfs_compat.http.serving_client import HttpClient
from datumaro.components.abstracts.model_interpreter import LauncherInputType, ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
LauncherForDedicatedInferenceServer,
ProtocolType,
)
__all__ = ["OVMSLauncher"]
TClient = Union[GrpcClient, HttpClient]
[docs]
class OVMSLauncher(LauncherForDedicatedInferenceServer[TClient]):
"""Inference launcher for OVMS (OpenVINO™ Model Server) (https://github.com/openvinotoolkit/model_server)
Parameters:
model_name: Name of the model. It should match with the model name loaded in the server instance.
model_interpreter_path: Python source code path which implements a model interpreter.
The model interpreter implement pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""
def _init_client(self) -> TClient:
tls_config = self.tls_config.as_dict() if self.tls_config is not None else None
if self.protocol_type == ProtocolType.grpc:
return make_grpc_client(self.url, tls_config)
if self.protocol_type == ProtocolType.http:
return make_http_client(self.url, tls_config)
raise NotImplementedError(self.protocol_type)
def _check_server_health(self) -> None:
status = self._client.get_model_status(
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
log.info(f"Health check succeeded: {status}")
def _init_metadata(self):
self._metadata = self._client.get_model_metadata(
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
log.info(f"Received metadata: {self._metadata}")
[docs]
def infer(self, inputs: LauncherInputType) -> List[ModelPred]:
# Please see the following link for the input and output type of self._client.predict()
# https://github.com/openvinotoolkit/model_server/blob/releases/2022/3/client/python/ovmsclient/lib/docs/grpc_client.md#method-predict
# The input is Dict[str, np.ndarray].
# The output is Dict[str, np.ndarray] (If the model has multiple outputs),
# or np.ndarray (If the model has one single output).
pred_inputs = {self._input_key: inputs} if isinstance(inputs, np.ndarray) else inputs
results = self._client.predict(
inputs=pred_inputs,
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
# If there is only one output key,
# it returns `np.ndarray`` rather than `Dict[str, np.ndarray]`.
# Please see ovmsclient.tfs_compat.grpc.responses.GrpcPredictResponse
if isinstance(results, np.ndarray):
results = {self._output_key: results}
outputs_group_by_item = [
{key: output for key, output in zip(results.keys(), outputs)}
for outputs in zip(*results.values())
]
return outputs_group_by_item
@property
def _input_key(self):
if hasattr(self, "__input_key"):
return self.__input_key
metadata_inputs = self._metadata.get("inputs")
if metadata_inputs is None:
raise DatumaroError("Cannot get metadata of the outputs.")
if len(metadata_inputs.keys()) > 1:
raise DatumaroError(
f"More than two model inputs are not allowed: {metadata_inputs.keys()}."
)
self.__input_key = next(iter(metadata_inputs.keys()))
return self.__input_key
@property
def _output_key(self):
if hasattr(self, "__output_key"):
return self.__output_key
metadata_outputs = self._metadata.get("outputs")
if metadata_outputs is None:
raise DatumaroError("Cannot get metadata of the outputs.")
if len(metadata_outputs.keys()) > 1:
raise DatumaroError(
f"More than two model outputs are not allowed: {metadata_outputs.keys()}."
)
self.__output_key = next(iter(metadata_outputs.keys()))
return self.__output_key