Source code for datumaro.plugins.data_formats.mvtec.exporter

# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp
from enum import Enum, auto

import cv2
import numpy as np

from datumaro.components.annotation import AnnotationType
from datumaro.components.errors import MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.media import Image
from datumaro.util import parse_str_enum_value

from .format import MvtecPath, MvtecTask


class LabelmapType(Enum):
    """Enumeration of label-map kinds.

    NOTE(review): nothing in the visible part of this module references this
    enum, and the ``kitti`` member name suggests it was carried over from
    another exporter -- confirm before relying on it.
    """

    # Explicit values equal to what ``auto()`` assigns (1, 2, ...).
    kitti = 1
    source = 2
class MvtecExporter(Exporter):
    """Export a dataset into the MVTec directory layout.

    Images are saved under ``<save_dir>/<subset>/`` (when media saving is
    enabled) and binary ground-truth masks under
    ``<save_dir>/<MvtecPath.MASK_DIR>/``.

    The detection and segmentation tasks write to the *same* mask file for a
    given item. When both tasks are enabled and an item carries both bbox and
    mask annotations, the segmentation mask is written last and therefore
    overwrites the rasterized boxes.

    Args:
        extractor: Dataset (extractor) to export.
        save_dir: Destination directory; created if missing.
        tasks: ``None`` (all tasks), a single ``MvtecTask``, or a list/set of
            ``MvtecTask`` members or their string names.
        **kwargs: Forwarded to ``Exporter.__init__``.
    """

    DEFAULT_IMAGE_EXT = MvtecPath.IMAGE_EXT

    @staticmethod
    def _split_tasks_string(s):
        """Parse a comma-separated task list into ``MvtecTask`` members.

        Names are stripped and lower-cased before lookup; an unknown name
        raises ``KeyError`` from the enum lookup.
        """
        return [MvtecTask[i.strip().lower()] for i in s.split(",")]

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base exporter CLI parser with the ``--tasks`` option."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument(
            "--tasks",
            type=cls._split_tasks_string,
            help="MVTec task filter, comma-separated list of {%s} "
            "(default: all)" % ", ".join(t.name for t in MvtecTask),
        )
        return parser

    def __init__(
        self,
        extractor,
        save_dir,
        tasks=None,
        **kwargs,
    ):
        super().__init__(extractor, save_dir, **kwargs)

        # Kept as an assert so callers observing AssertionError are unaffected.
        assert tasks is None or isinstance(tasks, (MvtecTask, list, set))
        if tasks is None:
            tasks = set(MvtecTask)
        elif isinstance(tasks, MvtecTask):
            tasks = {tasks}
        else:
            tasks = {parse_str_enum_value(t, MvtecTask) for t in tasks}
        self._tasks = tasks

    def _apply_impl(self):
        """Write images and ground-truth masks for every item of every subset.

        Raises:
            MediaTypeError: If the dataset declares a media type that is not
                an ``Image`` subclass.
        """
        media_type = self._extractor.media_type()
        if media_type and not issubclass(media_type, Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)

        export_bboxes = MvtecTask.detection in self._tasks
        export_masks = MvtecTask.segmentation in self._tasks

        for subset_name, subset in self._extractor.subsets().items():
            # Created per subset (not up front) so that a dataset without
            # subsets still produces no mask directory, as before.
            if export_bboxes or export_masks:
                os.makedirs(osp.join(self._save_dir, MvtecPath.MASK_DIR), exist_ok=True)

            for item in subset:
                if self._save_media:
                    self._save_image(item, subdir=subset_name)

                if export_bboxes:
                    bboxes = [a for a in item.annotations if a.type == AnnotationType.bbox]
                    if bboxes:
                        self._write_mask(item, self._render_bbox_mask(item, bboxes))

                if export_masks:
                    masks = [a for a in item.annotations if a.type == AnnotationType.mask]
                    if masks:
                        # Only the first mask annotation is exported; any
                        # further masks on the item are ignored.
                        self._write_mask(item, masks[0].image.astype(np.uint8))

    @staticmethod
    def _render_bbox_mask(item, bboxes):
        """Rasterize *bboxes* as filled rectangles of value 1 on a zero mask."""
        # NOTE(review): assumes ``item.media.size`` is (height, width) so it
        # can be used directly as the array shape -- confirm.
        mask = np.zeros((*item.media.size,), dtype=np.uint8)
        for bbox in bboxes:
            # Datumaro's ``Bbox.get_bbox()`` yields [x, y, w, h]; unpacking it
            # as (x, y, h, w) transposed every non-square box.
            x, y, w, h = bbox.get_bbox()
            mask = cv2.rectangle(
                mask,
                np.array((x, y), dtype=np.int32),
                np.array((x + w, y + h), dtype=np.int32),
                (1),
                -1,  # negative thickness -> filled rectangle
            )
        return mask

    def _write_mask(self, item, mask):
        """Save a 0/1 *mask* as a 0/255 image at the item's ground-truth path."""
        mask_path = osp.join(
            self._save_dir,
            MvtecPath.MASK_DIR,
            item.id + MvtecPath.MASK_POSTFIX + MvtecPath.MASK_EXT,
        )
        # Item ids may contain sub-directories; create all missing levels
        # without a racy exists() check.
        os.makedirs(osp.dirname(mask_path), exist_ok=True)
        cv2.imwrite(mask_path, mask * 255)

    def get_label(self, label_id):
        """Return the name of the label category with index *label_id*."""
        return self._extractor.categories()[AnnotationType.label].items[label_id].name
class MvtecClassificationExporter(MvtecExporter):
    """``MvtecExporter`` preset fixed to the classification task.

    Any ``tasks`` keyword supplied by the caller is overridden.
    """

    def __init__(self, *args, **kwargs):
        forced = {**kwargs, "tasks": MvtecTask.classification}
        super().__init__(*args, **forced)
class MvtecSegmentationExporter(MvtecExporter):
    """``MvtecExporter`` preset fixed to the segmentation task.

    Any ``tasks`` keyword supplied by the caller is overridden.
    """

    def __init__(self, *args, **kwargs):
        forced = {**kwargs, "tasks": MvtecTask.segmentation}
        super().__init__(*args, **forced)
class MvtecDetectionExporter(MvtecExporter):
    """``MvtecExporter`` preset fixed to the detection task.

    Any ``tasks`` keyword supplied by the caller is overridden.
    """

    def __init__(self, *args, **kwargs):
        forced = {**kwargs, "tasks": MvtecTask.detection}
        super().__init__(*args, **forced)