Source code for datumaro.plugins.data_formats.brats

# Copyright (C) 2022-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import errno
import glob
import os.path as osp
from typing import List, Optional

import nibabel as nib
import numpy as np

from datumaro.components.annotation import AnnotationType, ExtractedMask, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import MultiframeImage


class BratsPath:
    """Names and file extension that describe the on-disk BraTS layout."""

    # Prefix of the directory names that hold image volumes.
    IMAGES_DIR = "images"

    # Name of the label-names file; also the prefix of mask directory names.
    LABELS = "labels"

    # Extension shared by image and mask volume files.
    DATA_EXT = ".nii.gz"
class BratsBase(SubsetBase):
    """Read one subset of a BraTS-layout dataset stored as ``.nii.gz`` volumes.

    The directory given to the constructor holds the image volumes. The part
    of its name after the first ``len("images")`` characters is kept as the
    subset suffix: ``"Tr"`` selects the ``train`` subset, ``"Ts"`` selects
    ``test`` and anything else leaves the subset unset (``None`` is passed to
    the base class). Mask volumes are read from the sibling directory
    ``labels<suffix>`` and label names from the optional sibling file
    ``labels``.
    """

    def __init__(self, path: str, *, ctx: Optional[ImportContext] = None):
        """Load categories and items for the images directory ``path``.

        Args:
            path: Directory containing the image volumes.
            ctx: Import context forwarded unchanged to the base class.

        Raises:
            NotADirectoryError: If ``path`` is not an existing directory.
        """
        if not osp.isdir(path):
            raise NotADirectoryError(errno.ENOTDIR, "Can't find dataset directory", path)

        # Only the length of the "images" prefix is used here; the directory
        # name is not checked to actually start with it.
        self._subset_suffix = osp.basename(path)[len(BratsPath.IMAGES_DIR) :]
        subset = {"Tr": "train", "Ts": "test"}.get(self._subset_suffix)

        super().__init__(subset=subset, media_type=MultiframeImage, ctx=ctx)

        self._root_dir = osp.dirname(path)
        self._categories = self._load_categories()
        self._items = list(self._load_items(path).values())

    def _load_categories(self):
        """Build label categories from the ``labels`` file next to ``path``.

        Each line of the file, stripped of surrounding whitespace, becomes one
        label, in file order. When the file does not exist the categories are
        left empty.

        Returns:
            A dict mapping ``AnnotationType.label`` to the ``LabelCategories``.
        """
        label_cat = LabelCategories()

        labels_path = osp.join(self._root_dir, BratsPath.LABELS)
        if osp.isfile(labels_path):
            with open(labels_path, encoding="utf-8") as f:
                for line in f:
                    label_cat.add(line.strip())

        return {AnnotationType.label: label_cat}

    def _load_items(self, path):
        """Create dataset items from the image volumes and attach their masks.

        Every ``*.nii.gz`` file in ``path`` becomes one item whose media is a
        ``MultiframeImage`` made of the slices taken along the third axis of
        the volume. Every ``*.nii.gz`` file in the ``labels<suffix>`` directory
        is split the same way; for each slice, one ``ExtractedMask`` is created
        per distinct value found in that slice, tagged with the slice index in
        the ``"image_id"`` attribute.

        Args:
            path: Directory containing the image volumes.

        Returns:
            A dict mapping item id (file name without ``.nii.gz``) to its
            ``DatasetItem``.
        """
        items = {}
        ext_len = len(BratsPath.DATA_EXT)

        for image_path in glob.glob(osp.join(path, f"*{BratsPath.DATA_EXT}")):
            data = nib.load(image_path).get_fdata()
            item_id = osp.basename(image_path)[:-ext_len]

            images = [data[:, :, i] for i in range(data.shape[2])]
            items[item_id] = DatasetItem(
                id=item_id, subset=self._subset, media=MultiframeImage(images, path=image_path)
            )

        masks_dir = osp.join(self._root_dir, BratsPath.LABELS + self._subset_suffix)
        for mask_path in glob.glob(osp.join(masks_dir, f"*{BratsPath.DATA_EXT}")):
            data = nib.load(mask_path).get_fdata()

            # The id has to come from the mask file itself. Deriving it from
            # the image loop's leftover variable would attach every mask to
            # the last image and fail outright when there are no images.
            item_id = osp.basename(mask_path)[:-ext_len]

            if item_id not in items:
                # Mask without a matching image: keep it in the same subset as
                # the image-backed items, just without media.
                items[item_id] = DatasetItem(id=item_id, subset=self._subset)

            anno = []
            for i in range(data.shape[2]):
                np_mask = data[:, :, i]
                for class_id in np.unique(np_mask):
                    anno.append(
                        ExtractedMask(
                            index_mask=np_mask,
                            index=class_id,
                            label=class_id,
                            attributes={"image_id": i},
                        )
                    )

            # Record the annotation type only when a mask was really produced.
            if anno:
                self._ann_types.add(AnnotationType.mask)

            items[item_id].annotations = anno

        return items
class BratsImporter(Importer):
    """Importer entry point for datasets in the BraTS directory layout."""

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """Register the file requirement used to recognise this format.

        A file matching ``*/*.nii.gz`` is required; the requirement is
        declared as the single alternative of a ``require_any`` group.
        """
        volume_pattern = f"*/*{BratsPath.DATA_EXT}"
        with context.require_any(), context.alternative():
            context.require_file(volume_pattern)

    @classmethod
    def find_sources(cls, path):
        """Search ``path`` for entries named ``images*`` as "brats" sources.

        The search itself is delegated to ``_find_sources_recursive``, whose
        result is returned unchanged.
        """
        images_pattern = f"{BratsPath.IMAGES_DIR}*"
        return cls._find_sources_recursive(path, "", "brats", filename=images_pattern)

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the last extension component of the data files.

        ``osp.splitext`` only splits off the final suffix, so for
        ``".nii.gz"`` the result is ``[".gz"]``.
        """
        _, last_ext = osp.splitext(BratsPath.DATA_EXT)
        return [last_ext]