# Source code for datumaro.plugins.data_formats.ava.ava

# Copyright (C) 2022-2023 Intel Corporation
# SPDX-License-Identifier: MIT

import csv
import errno
import os
import os.path as osp
from typing import List, Optional

import google.protobuf.text_format as text_format

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.format_detection import FormatDetectionConfidence, FormatDetectionContext
from datumaro.components.importer import ImportContext, Importer
from datumaro.components.media import Image
from datumaro.util.os_util import find_files

from . import ava_label_pb2

class AvaPath:
    """Directory names, file extensions and file-name parts of an AVA dataset."""

    # Directory (under the dataset root) that holds the frame images.
    IMAGE_DIR = "frames"
    IMAGE_EXT = ".jpg"

    # Directory (under the dataset root) that holds the CSV annotation files
    # and the label list.
    ANNOTATION_DIR = "annotations"
    ANNOTATION_EXT = ".csv"

    # Annotation files are named "<prefix><subset><version><ext>".
    ANNOTATION_PREFIX = "ava_"
    ANNOTATION_VERSION = "_v2.2"

    # File name of the protobuf text file listing the action labels.
    LABEL_LIST = ANNOTATION_PREFIX + "action_list" + ANNOTATION_VERSION + ".pbtxt"

    # Extension of the proposal files looked up next to the annotations.
    PROPOSAL_EXT = ".pkl"
class AvaBase(SubsetBase):
    """Load one AVA annotation CSV file as a single subset.

    The CSV file must be located directly inside a directory named
    ``AvaPath.ANNOTATION_DIR``; the dataset root is the part of the path that
    precedes that directory and must contain ``AvaPath.IMAGE_DIR``.
    """

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        ctx: Optional[ImportContext] = None,
    ):
        """Locate the dataset layout around ``path`` and load its content.

        Args:
            path: Path to the annotation CSV file.
            subset: Subset name. When empty, it is derived from the file name
                by removing the extension, ``AvaPath.ANNOTATION_PREFIX`` and
                ``AvaPath.ANNOTATION_VERSION``.
            ctx: Import context forwarded to the base class.

        Raises:
            FileNotFoundError: If the CSV file, the annotation directory
                layout, the image directory or the label list is missing.
        """
        if not osp.isfile(path):
            raise FileNotFoundError(errno.ENOENT, "Can't find CSV file", path)
        self._path = path

        if not subset:
            subset = (
                osp.splitext(osp.basename(path))[0]
                .replace(AvaPath.ANNOTATION_PREFIX, "")
                .replace(AvaPath.ANNOTATION_VERSION, "")
            )

        super().__init__(subset=subset, ctx=ctx)

        # The root path is everything in front of the last "annotations" component.
        if path.endswith(osp.join(AvaPath.ANNOTATION_DIR, osp.basename(path))):
            self._rootpath = path.rsplit(AvaPath.ANNOTATION_DIR, maxsplit=1)[0]
        else:
            raise FileNotFoundError(
                errno.ENOENT,
                f"Annotation path ({path}) should be under the directory which is named {AvaPath.ANNOTATION_DIR}. "
                "If not, Datumaro fails to find the root path for this dataset.",
            )

        if self._rootpath and osp.isdir(osp.join(self._rootpath, AvaPath.IMAGE_DIR)):
            self._images_dir = osp.join(self._rootpath, AvaPath.IMAGE_DIR)
        else:
            raise FileNotFoundError(
                errno.ENOENT,
                f"Root path ({self._rootpath}) should contain the directory which is named {AvaPath.IMAGE_DIR}. "
                "If not, Datumaro fails to find the image directory path.",
            )

        self._infos = self._load_infos(osp.dirname(path))

        category_path = osp.join(self._rootpath, AvaPath.ANNOTATION_DIR, AvaPath.LABEL_LIST)
        self._categories = self._load_categories(category_path)

        self._items = self._load_items(path)

    def _load_infos(self, path):
        """Collect the proposal files found in the directory ``path``.

        Returns:
            A dict mapping ``"<name>_proposals"`` to the proposal file name,
            where ``<name>`` is the last ``_``-separated token of the file
            name part that precedes the first dot.
        """
        infos = {}
        for file in os.listdir(path):
            if file.endswith(AvaPath.PROPOSAL_EXT):
                name = file.split(".")[0].split("_")[-1]
                infos[name + "_proposals"] = file
        return infos

    def _load_categories(self, category_path):
        """Parse the label list (protobuf text format) into label categories.

        Raises:
            FileNotFoundError: If ``category_path`` does not exist.
        """
        if not osp.exists(category_path):
            raise FileNotFoundError(
                errno.ENOENT,
                f"Label lists cannot be found in ({category_path}). "
                "If not, Datumaro fails to import AVA action dataset.",
            )

        with open(category_path, "r", encoding="utf-8") as f:
            pbtxt_data = f.read()

        label_list = ava_label_pb2.LabelList()
        text_format.Parse(pbtxt_data, label_list)

        categories = LabelCategories()

        # dummy class for id 0 for ava data
        # (an empty label list is treated the same way instead of raising IndexError)
        if not label_list.label or label_list.label[0].label_id != 0:
            categories.add("no action")

        for node in label_list.label:
            categories.add(node.name)

        return {AnnotationType.label: categories}

    def _load_items(self, ann_file):
        """Read the annotation CSV and build one item per (video, timestamp).

        Each row is read as
        ``video_id, timestamp, x1, y1, x2, y2, label, entity_id``. Rows that
        share the same video id and timestamp add boxes to the same item.
        For subsets whose name contains ``"excluded_timestamps"`` only the
        items are created and no annotation columns are read.
        """
        items = {}
        skip_annotations = "excluded_timestamps" in self._subset

        with open(ann_file, "r", encoding="utf-8") as f:
            for data in csv.reader(f):
                if not data:
                    # csv.reader yields an empty list for a blank line
                    continue

                video_id = data[0]
                timestamp = data[1]

                item_id = video_id + "/" + timestamp
                image_path = osp.join(
                    self._images_dir, video_id, item_id.replace("/", "_") + AvaPath.IMAGE_EXT
                )

                item = items.get(item_id)
                if item is None:
                    item = DatasetItem(
                        id=item_id,
                        subset=self._subset,
                        media=Image.from_file(path=image_path),
                    )
                    items[item_id] = item

                if skip_annotations:
                    continue

                bbox = list(map(float, data[2:6]))  # (x1, y1, x2, y2)
                label = int(data[6])
                entity_id = int(data[7])

                item.annotations.append(
                    Bbox(
                        x=bbox[0],
                        y=bbox[1],
                        w=bbox[2] - bbox[0],
                        h=bbox[3] - bbox[1],
                        label=label,
                        attributes={"track_id": entity_id},
                    )
                )

        return items.values()
class AvaImporter(Importer):
    """Find AVA annotation CSV files under a dataset root."""

    # Extension (without the dot) of the annotation files to look for.
    _ANNO_EXT = "csv"

    @classmethod
    def find_sources(cls, path):
        """Return one source description per AVA annotation file.

        Searches ``<path>/<AvaPath.ANNOTATION_DIR>`` for CSV files and keeps
        those whose file name contains ``AvaPath.ANNOTATION_PREFIX``.

        Returns:
            A list of ``{"url": <file path>, "format": AvaBase.NAME}`` dicts.
        """
        ann_files = find_files(
            osp.join(path, AvaPath.ANNOTATION_DIR),
            exts=cls._ANNO_EXT,
            recursive=True,
            max_depth=1,
        )

        # Test the file name only: matching against the whole path would accept
        # every CSV file whenever a parent directory happens to contain the prefix.
        return [
            {"url": ann_file, "format": AvaBase.NAME}
            for ann_file in ann_files
            if AvaPath.ANNOTATION_PREFIX in osp.basename(ann_file)
        ]

    @classmethod
    def detect(cls, context: FormatDetectionContext) -> FormatDetectionConfidence:
        """Run the base-class detection, then report medium confidence."""
        super().detect(context)
        return FormatDetectionConfidence.MEDIUM

    @classmethod
    def get_file_extensions(cls) -> List[str]:
        """Return the annotation file extension, including the leading dot."""
        return [f".{cls._ANNO_EXT}"]
class AvaExporter(Exporter):
    """Write a dataset in the AVA directory layout.

    Annotations go to ``<save_dir>/<AvaPath.ANNOTATION_DIR>`` as one CSV file
    per subset; frames (when media saving is enabled) go to
    ``<save_dir>/<AvaPath.IMAGE_DIR>/<first part of the item id>/``.
    """

    DEFAULT_IMAGE_EXT = AvaPath.IMAGE_EXT

    def _apply_impl(self):
        """Export the label list, the per-subset CSV files and the images.

        Raises:
            MediaTypeError: If the dataset media type is set and is not an
                ``Image`` subclass.
        """
        if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
            raise MediaTypeError("Media type is not an image")

        save_dir = self._save_dir
        ann_dir = osp.join(save_dir, AvaPath.ANNOTATION_DIR)
        os.makedirs(ann_dir, exist_ok=True)

        frame_dir = osp.join(save_dir, AvaPath.IMAGE_DIR)

        if self._save_dataset_meta:
            label_categories = self._extractor.categories()[AnnotationType.label]
            message = ava_label_pb2.LabelList()
            for k, v in label_categories._indices.items():
                label = ava_label_pb2.Label(name=k, label_id=v)
                message.label.extend([label])

            # Since protobuf may not be possible to describe zero while it is valid,
            # the label with label_id=0 will be ignored in the written pbtxt.
            # But this is well interpreted as zero during reading the pbtxt.
            pbtxt_string = text_format.MessageToString(message)
            with open(osp.join(ann_dir, AvaPath.LABEL_LIST), "w", encoding="utf-8") as f:
                f.write(pbtxt_string)

        for subset_name, subset in self._extractor.subsets().items():
            ann_file = osp.join(
                ann_dir,
                AvaPath.ANNOTATION_PREFIX
                + subset_name
                + AvaPath.ANNOTATION_VERSION
                + AvaPath.ANNOTATION_EXT,
            )
            with open(ann_file, mode="w", newline="", encoding="utf-8") as csvfile:
                csvwriter = csv.writer(csvfile)
                for item in subset:
                    # The "/"-separated parts of the item id form the leading
                    # columns of every row written for this item.
                    item_row = item.id.split("/")

                    if self._save_media:
                        image_path = osp.join(
                            osp.join(frame_dir, item_row[0]),
                            item.id.replace("/", "_") + AvaPath.IMAGE_EXT,
                        )
                        self._save_image(
                            item,
                            path=image_path,
                        )

                    # One row per bounding box, written as (x1, y1, x2, y2, label, track id).
                    for bbox in item.annotations:
                        if bbox.type != AnnotationType.bbox:
                            continue
                        csvwriter.writerow(
                            item_row
                            + [
                                bbox.x,
                                bbox.y,
                                bbox.x + bbox.w,
                                bbox.y + bbox.h,
                                bbox.label,
                                bbox.attributes.get("track_id", 0),
                            ]
                        )