# Source code for datumaro.plugins.explorer

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import List, Sequence

import numpy as np
from tokenizers import Tokenizer

from datumaro.components.annotation import Annotation, HashKey
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import MediaTypeError
from datumaro.components.media import Image
from datumaro.plugins.openvino_plugin.launcher import OpenvinoLauncher


class ExplorerLauncher(OpenvinoLauncher):
    """OpenVINO launcher that turns text prompts and dataset items into hash keys.

    Text is tokenized with the ``openai/clip-vit-base-patch32`` tokenizer
    (loaded lazily on first use) and passed through the inherited ``infer`` /
    ``postprocess`` pipeline; dataset items go through the inherited ``launch``.
    """

    def __init__(
        self,
        description=None,
        weights=None,
        interpreter=None,
        model_dir=None,
        model_name=None,
        output_layers=None,
        device=None,
    ):
        """Initialize the base launcher and the explorer-specific state.

        All arguments are forwarded unchanged (positionally, in this order) to
        ``OpenvinoLauncher.__init__``. ``device`` falls back to ``"cpu"`` for
        the value stored on this instance when it is falsy.
        """
        super().__init__(
            description, weights, interpreter, model_dir, model_name, output_layers, device
        )

        self._device = device or "cpu"
        # Only the first output of the loaded network is tracked.
        self._output_blobs = next(iter(self._net.outputs))
        # Created on demand by _tokenize() so construction stays cheap.
        self._tokenizer = None

    def _tokenize(self, texts: str, context_length: int = 77, truncate: bool = True):
        """Encode ``texts`` into a zero-padded ``(1, context_length)`` array of token ids.

        Args:
            texts: The text to encode.
            context_length: Number of token slots in the returned array.
            truncate: When the encoded text is longer than ``context_length``,
                cut it to fit and keep the original last token in the final
                slot. When ``False``, an over-long text is rejected.

        Returns:
            A ``numpy`` array of shape ``(1, context_length)`` (default float
            dtype of ``np.zeros``) holding the token ids followed by zeros.

        Raises:
            ValueError: If the encoded text exceeds ``context_length`` and
                ``truncate`` is ``False``.
        """
        if not self._tokenizer:
            checkpoint = "openai/clip-vit-base-patch32"
            self._tokenizer = Tokenizer.from_pretrained(checkpoint)

        # ``.ids`` is already a plain list of ints.
        tokens = self._tokenizer.encode(texts).ids

        if len(tokens) > context_length:
            if not truncate:
                raise ValueError(
                    f"Input text produces {len(tokens)} tokens, which exceeds "
                    f"the context length {context_length}"
                )
            # Preserve the final token of the full encoding in the last slot.
            # (The original read ``tokens.ids[-1]``, which fails on a list.)
            eot_token = tokens[-1]
            tokens = tokens[:context_length]
            tokens[-1] = eot_token

        result = np.zeros((1, context_length))
        result[0, : len(tokens)] = tokens
        return result

    def _compute_hash(self, features):
        """Binarize ``features`` by sign and pack the bits along the last axis.

        Positive values become bit 1; zero and negative values become bit 0.
        The bits are packed into ``uint8`` bytes with ``np.packbits``.
        """
        # sign -> {-1, 0, 1}; clipping at 0 leaves {0, 1}.
        bits = np.clip(np.sign(features), 0, None).astype(np.uint8)
        return np.packbits(bits, axis=-1)

    def infer_text(self, text: str, use_prompt: bool = True) -> HashKey:
        """Run inference on a text query and return the first resulting annotation.

        Args:
            text: The query text.
            use_prompt: When ``True``, the text is wrapped as
                ``"a photo of a {text}"`` before tokenization.

        Returns:
            The first annotation produced by ``postprocess`` for the first
            prediction.
        """
        prompt_text = f"a photo of a {text}" if use_prompt else text
        inputs = self._tokenize(prompt_text)
        preds = self.infer(inputs)
        anns = self.postprocess(preds[0], None)
        return anns[0]

    def infer_item(self, item: DatasetItem) -> HashKey:
        """Run the launcher on a single dataset item and return its first annotation."""
        anns = self.launch([item])[0]
        return anns[0]

    def launch(self, batch: Sequence[DatasetItem]) -> List[List[Annotation]]:
        """Delegate to the base launcher's ``launch`` and return its outputs unchanged."""
        return super().launch(batch)

    def type_check(self, item):
        """Return ``True`` if ``item.media`` is an ``Image``.

        Raises:
            MediaTypeError: If ``item.media`` is not an ``Image`` instance.
        """
        if not isinstance(item.media, Image):
            raise MediaTypeError(f"Media type should be Image, Current type={type(item.media)}")
        return True