# Source code for datumaro.plugins.data_formats.brats_numpy
# Copyright (C) 2022-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import errno
import os.path as osp
from typing import List, Optional
import numpy as np
from datumaro.components.annotation import AnnotationType, Cuboid3d, ExtractedMask, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.format_detection import FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import MultiframeImage
from datumaro.util.pickle_util import PickleLoader
[docs]
class BratsNumpyPath:
    """File names and name suffixes that make up a BraTS NumPy dataset."""

    # Per-item arrays are stored as ``<item id><suffix>.npy``.
    DATA_SUFFIX = "_data_cropped"
    LABEL_SUFFIX = "_label_cropped"

    # Pickled files: the list of item ids and the per-item boxes.
    IDS_FILE = "val_ids.p"
    BOXES_FILE = "val_brain_bbox.p"

    # Text file with one label name per line.
    LABELS_FILE = "labels"
[docs]
class BratsNumpyBase(SubsetBase):
    """Read one BraTS NumPy subset into multi-frame dataset items.

    The reader is pointed at the pickled ids file; the label names file, the
    boxes file and the per-item ``.npy`` arrays are all looked up in the same
    directory. Every array is loaded eagerly during construction.
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """Load categories and items from the directory that holds ``path``.

        Args:
            path: Path to the pickled ids file.
            subset: Subset name forwarded to ``SubsetBase``.
            ctx: Import context forwarded to ``SubsetBase``.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        path_is_file = osp.isfile(path)
        if not path_is_file:
            raise FileNotFoundError(errno.ENOENT, "Can't find annotations file", path)

        super().__init__(subset=subset, media_type=MultiframeImage, ctx=ctx)

        self._root_dir = osp.dirname(path)
        self._categories = self._load_categories()
        loaded_items = self._load_items(path)
        self._items = list(loaded_items.values())

    def _load_categories(self):
        """Build the label categories from the optional labels file.

        Each line of the file, stripped of surrounding whitespace, is added as
        one label in file order. When the file does not exist the categories
        stay empty.

        Returns:
            A dict mapping ``AnnotationType.label`` to the label categories.
        """
        categories = LabelCategories()
        names_file = osp.join(self._root_dir, BratsNumpyPath.LABELS_FILE)

        if osp.isfile(names_file):
            with open(names_file, encoding="utf-8") as stream:
                for raw_name in stream:
                    categories.add(raw_name.strip())

        return {AnnotationType.label: categories}

    def _load_items(self, path):
        """Create one dataset item per id listed in the ids file.

        Args:
            path: Path to the pickled ids file.

        Returns:
            A dict mapping each item id to its ``DatasetItem``. An id that
            occurs more than once keeps only the item built last.
        """
        root = self._root_dir

        with open(path, "rb") as ids_stream:
            ids = PickleLoader.restricted_load(ids_stream)

        # The boxes file is optional; without it no cuboids are produced.
        boxes_path = osp.join(root, BratsNumpyPath.BOXES_FILE)
        boxes = None
        if osp.isfile(boxes_path):
            with open(boxes_path, "rb") as boxes_stream:
                boxes = PickleLoader.restricted_load(boxes_stream)

        # TODO(vinnamki): Apply lazy loading for images and masks
        items = {}
        for position, item_id in enumerate(ids):
            image_path = osp.join(root, item_id + BratsNumpyPath.DATA_SUFFIX + ".npy")
            media = self._load_frames(image_path)

            mask_path = osp.join(root, item_id + BratsNumpyPath.LABEL_SUFFIX + ".npy")
            annotations = self._load_masks(mask_path)

            if boxes is not None:
                # Boxes are matched to ids by their position in the ids list.
                # NOTE(review): entries 0 and 1 of a box are used as position
                # and rotation respectively - confirm against the data format.
                box = boxes[position]
                annotations.append(Cuboid3d(position=list(box[0]), rotation=list(box[1])))
                self._ann_types.add(AnnotationType.cuboid_3d)

            items[item_id] = DatasetItem(id=item_id, media=media, annotations=annotations)

        return items

    def _load_frames(self, image_path):
        """Load the image array at ``image_path`` as a multi-frame image.

        The first entry of the stored array is taken and its axes are
        reversed; every slice along the third axis becomes one frame.

        Returns:
            A ``MultiframeImage``, or ``None`` when the file does not exist.
        """
        if not osp.isfile(image_path):
            return None

        volume = np.load(image_path)[0].transpose()
        frames = [volume[:, :, frame_idx] for frame_idx in range(volume.shape[2])]
        return MultiframeImage(frames, path=image_path)

    def _load_masks(self, mask_path):
        """Load mask annotations from the label array at ``mask_path``.

        The array is sliced the same way as the image array. For every slice,
        one ``ExtractedMask`` is created per distinct value found in it; the
        value is used both as the mask index and as the label, and the slice
        number is stored in the ``image_id`` attribute.

        Returns:
            A list of mask annotations; empty when the file does not exist.
        """
        masks = []
        if not osp.isfile(mask_path):
            return masks

        volume = np.load(mask_path)[0].transpose()
        for frame_idx in range(volume.shape[2]):
            frame_mask = volume[:, :, frame_idx]
            for class_id in np.unique(frame_mask):
                masks.append(
                    ExtractedMask(
                        index_mask=frame_mask,
                        index=class_id,
                        label=class_id,
                        attributes={"image_id": frame_idx},
                    )
                )
                self._ann_types.add(AnnotationType.mask)

        return masks
[docs]
class BratsNumpyImporter(Importer):
    """Importer that recognises BraTS NumPy datasets by their ids file."""

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> None:
        """Require the ids file to be present for the format to match."""
        ids_file = BratsNumpyPath.IDS_FILE
        context.require_file(ids_file)

    @classmethod
    def find_sources(cls, path):
        """Search ``path`` for ids files via ``_find_sources_recursive``.

        Returns:
            Whatever ``_find_sources_recursive`` returns for the
            ``"brats_numpy"`` format and the ids file name.
        """
        ids_file = BratsNumpyPath.IDS_FILE
        return cls._find_sources_recursive(path, "", "brats_numpy", filename=ids_file)

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return a one-element list holding the ids file's extension."""
        _, extension = osp.splitext(BratsNumpyPath.IDS_FILE)
        return [extension]