Source code for datumaro.plugins.data_formats.mnist_csv

# Copyright (C) 2021-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import errno
import os
import os.path as osp
from typing import List, Optional

import numpy as np

from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.format_detection import FormatDetectionConfidence
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


[docs] class MnistCsvPath: IMAGE_SIZE = 28 NONE_LABEL = -1
[docs] class MnistCsvBase(SubsetBase): def __init__( self, path: str, *, subset: Optional[str] = None, ctx: Optional[ImportContext] = None, ): if not osp.isfile(path): raise FileNotFoundError(errno.ENOENT, "Can't find annotations file", path) if not subset: file_name = osp.splitext(osp.basename(path))[0] subset = file_name.rsplit("_", maxsplit=1)[-1] super().__init__(subset=subset, ctx=ctx) self._dataset_dir = osp.dirname(path) self._categories = self._load_categories() self._items = list(self._load_items(path).values()) def _load_categories(self): if has_meta_file(self._dataset_dir): return { AnnotationType.label: LabelCategories.from_iterable( parse_meta_file(self._dataset_dir).keys() ) } label_cat = LabelCategories() labels_file = osp.join(self._dataset_dir, "labels.txt") if osp.isfile(labels_file): with open(labels_file, encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue label_cat.add(line) else: for i in range(10): label_cat.add(str(i)) return {AnnotationType.label: label_cat} def _load_items(self, path): items = {} with open(path, "r", encoding="utf-8") as f: annotation_table = f.readlines() metafile = osp.join(self._dataset_dir, "meta_%s.csv" % self._subset) meta = [] if osp.isfile(metafile): with open(metafile, "r", encoding="utf-8") as f: meta = f.readlines() for i, data in enumerate(annotation_table): data = data.split(",") item_anno = [] try: label = int(data[0]) except ValueError: continue if label != MnistCsvPath.NONE_LABEL: item_anno.append(Label(label)) self._ann_types.add(AnnotationType.label) if 0 < len(meta): meta[i] = meta[i].strip().split(",") # support for single-channel image only image = None if 1 < len(data): if 0 < len(meta) and 1 < len(meta[i]): image = np.array([int(pix) for pix in data[1:]], dtype="uint8").reshape( int(meta[i][-2]), int(meta[i][-1]) ) else: image = np.array([int(pix) for pix in data[1:]], dtype="uint8").reshape(28, 28) if image is not None: image = Image.from_numpy(data=image) if 0 < len(meta) and len(meta[i]) in [1, 3]: i = meta[i][0] items[i] = DatasetItem(id=i, subset=self._subset, media=image, annotations=item_anno) return items
[docs] class MnistCsvImporter(Importer): DETECT_CONFIDENCE = FormatDetectionConfidence.MEDIUM _ANNO_EXT = ".csv"
[docs] @classmethod def find_sources(cls, path): return cls._find_sources_recursive( path, cls._ANNO_EXT, "mnist_csv", file_filter=lambda p: osp.basename(p).find("mnist_") != -1, )
[docs] @classmethod def get_file_extensions(cls) -> List[str]: return [cls._ANNO_EXT]
[docs] class MnistCsvExporter(Exporter): DEFAULT_IMAGE_EXT = ".png" def _apply_impl(self): if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image): raise MediaTypeError("Media type is not an image") os.makedirs(self._save_dir, exist_ok=True) if self._save_dataset_meta: self._save_meta_file(self._save_dir) for subset_name, subset in self._extractor.subsets().items(): data = [] item_ids = {} image_sizes = {} for item in subset: anns = [a.label for a in item.annotations if a.type == AnnotationType.label] label = MnistCsvPath.NONE_LABEL if anns: label = anns[0] if item.media and self._save_media: image = item.media if not image.has_data: data.append([label, None]) else: if ( image.data.shape[0] != MnistCsvPath.IMAGE_SIZE or image.data.shape[1] != MnistCsvPath.IMAGE_SIZE ): image_sizes[len(data)] = [image.data.shape[0], image.data.shape[1]] image = image.data.reshape(-1).astype(np.uint8).tolist() image.insert(0, label) data.append(image) else: data.append([label]) if item.id != str(len(data) - 1): item_ids[len(data) - 1] = item.id anno_file = osp.join(self._save_dir, "mnist_%s.csv" % subset_name) self.save_in_csv(anno_file, data) # it is't in the original format, # this is for storng other names and sizes of images if len(item_ids) or len(image_sizes): meta = [] if len(item_ids) and len(image_sizes): # other names and sizes of images size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE] for i in range(len(data)): w, h = image_sizes.get(i, size) meta.append([item_ids.get(i, i), w, h]) elif len(item_ids): # other names of images for i in range(len(data)): meta.append([item_ids.get(i, i)]) elif len(image_sizes): # other sizes of images size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE] for i in range(len(data)): meta.append(image_sizes.get(i, size)) metafile = osp.join(self._save_dir, "meta_%s.csv" % subset_name) self.save_in_csv(metafile, meta) self.save_labels()
[docs] def save_in_csv(self, path, data): with open(path, "w", encoding="utf-8") as f: for row in data: f.write(",".join([str(p) for p in row]) + "\n")
[docs] def save_labels(self): labels_file = osp.join(self._save_dir, "labels.txt") with open(labels_file, "w", encoding="utf-8") as f: f.writelines( l.name + "\n" for l in self._extractor.categories().get(AnnotationType.label, LabelCategories()) )