# Copyright (C) 2021-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
import logging as log
import os
import os.path as osp
from copy import deepcopy
# Disable B406: import_xml_sax - the library is used for writing
from xml.sax.saxutils import XMLGenerator # nosec
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.dataset_item_storage import ItemStatus
from datumaro.components.errors import DatasetExportError, MediaTypeError
from datumaro.components.exporter import Exporter
from datumaro.components.media import PointCloud
from datumaro.util import cast
from datumaro.util.image import find_images
from .format import KittiRawPath, OcclusionStates, PoseStates, TruncationStates
class _XmlAnnotationWriter:
    """Write tracklets to a text stream as a boost-serialization style XML archive.

    The document is emitted incrementally through ``XMLGenerator``. The current
    nesting depth is kept in ``self._level`` and rendered as one indentation
    unit per level at the start of every line.

    Args:
        file: Writable text stream. The XML header and doctype are written to
            it directly, everything else goes through the XML generator.
        tracklets: Sequence of mappings, one per tracklet. The ``"poses"`` key
            holds a sequence of pose mappings, the ``"attributes"`` key holds a
            name -> value mapping, every other key is written as a text element
            named after the key.
    """

    # Format constants
    _tracking_level = 0
    _tracklets_class_id = 0
    _tracklets_version = 0
    _tracklet_class_id = 1
    _tracklet_version = 1
    _poses_class_id = 2
    _poses_version = 0
    _pose_class_id = 3
    _pose_version = 1

    # XML headers
    _header = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"""
    _doctype = "<!DOCTYPE boost_serialization>"

    def __init__(self, file, tracklets):
        self._file = file
        self._tracklets = tracklets

        self._xmlgen = XMLGenerator(self._file, encoding="utf-8")
        self._level = 0

        # See reference for section headers here:
        # https://www.boost.org/doc/libs/1_40_0/libs/serialization/doc/traits.html
        # XML archives have regular structure, so we only include headers once
        self._add_tracklet_header = True
        self._add_poses_header = True
        self._add_pose_header = True

    # --- low-level helpers -------------------------------------------------

    def _indent(self, newline=True):
        """Emit an optional line break followed by the current indentation."""
        if newline:
            self._xmlgen.ignorableWhitespace("\n")
        self._xmlgen.ignorableWhitespace(" " * self._level)

    def _class_header(self, version, class_id):
        """Build the attribute set written on the first element of each kind."""
        return {
            "version": str(version),
            "tracking_level": str(self._tracking_level),
            "class_id": str(class_id),
        }

    def _add_text_element(self, tag, value):
        """Write ``<tag>str(value)</tag>`` on its own indented line."""
        self._indent(newline=True)
        self._xmlgen.startElement(tag, {})
        self._xmlgen.characters(str(value))
        self._xmlgen.endElement(tag)

    def _open_element(self, tag, attrs):
        """Start ``tag`` on a new line and descend one nesting level."""
        self._indent(newline=True)
        self._xmlgen.startElement(tag, attrs)
        self._level += 1

    def _close_element(self, tag):
        """Ascend one nesting level and end ``tag`` on a new line."""
        self._level -= 1
        self._indent(newline=True)
        self._xmlgen.endElement(tag)

    # --- document frame ----------------------------------------------------

    def _add_headers(self):
        # The declaration and the doctype are written to the stream directly,
        # only the line break between them goes through the generator.
        self._file.write(self._header)
        self._indent(newline=True)
        self._file.write(self._doctype)

    def _open_serialization(self):
        # The root element does not change the nesting level.
        self._indent(newline=True)
        self._xmlgen.startElement(
            "boost_serialization", {"version": "9", "signature": "serialization::archive"}
        )

    def _close_serialization(self):
        self._indent(newline=True)
        self._xmlgen.endElement("boost_serialization")

    def _add_count(self, count):
        self._add_text_element("count", count)

    def _add_item_version(self, version):
        self._add_text_element("item_version", version)

    # --- tracklets ---------------------------------------------------------

    def _open_tracklets(self, tracklets):
        self._open_element(
            "tracklets",
            self._class_header(self._tracklets_version, self._tracklets_class_id),
        )
        self._add_count(len(tracklets))
        self._add_item_version(self._tracklet_version)

    def _close_tracklets(self):
        self._close_element("tracklets")

    def _open_tracklet(self):
        if self._add_tracklet_header:
            # The "version" attribute comes from the version constant; it used
            # to be read from the class id constant, which only happened to
            # hold the same number.
            attrs = self._class_header(self._tracklet_version, self._tracklet_class_id)
            self._add_tracklet_header = False
        else:
            attrs = {}
        self._open_element("item", attrs)

    def _close_tracklet(self):
        self._close_element("item")

    def _add_tracklet(self, tracklet):
        """Write one tracklet; keys are emitted in the mapping's own order."""
        self._open_tracklet()

        for key, value in tracklet.items():
            if key == "poses":
                self._add_poses(value)
            elif key == "attributes":
                self._add_attributes(value)
            else:
                self._add_text_element(key, value)

        self._close_tracklet()

    # --- poses -------------------------------------------------------------

    def _open_poses(self, poses):
        if self._add_poses_header:
            attrs = self._class_header(self._poses_version, self._poses_class_id)
            self._add_poses_header = False
        else:
            attrs = {}
        self._open_element("poses", attrs)

        self._add_count(len(poses))
        # NOTE(review): _open_tracklets writes the version of the contained
        # items (_tracklet_version) here, while this writes the container
        # version (_poses_version) rather than _pose_version. Kept unchanged
        # so that the produced files stay the same - confirm which is intended.
        self._add_item_version(self._poses_version)

    def _close_poses(self):
        self._close_element("poses")

    def _add_poses(self, poses):
        self._open_poses(poses)

        for pose in poses:
            self._add_pose(pose)

        self._close_poses()

    def _open_pose(self):
        if self._add_pose_header:
            attrs = self._class_header(self._pose_version, self._pose_class_id)
            self._add_pose_header = False
        else:
            attrs = {}
        self._open_element("item", attrs)

    def _close_pose(self):
        self._close_element("item")

    def _add_pose(self, pose):
        """Write one pose; the ``"frame_id"`` key is never written."""
        self._open_pose()

        for key, value in pose.items():
            if key == "attributes":
                self._add_attributes(value)
            elif key != "frame_id":
                self._add_text_element(key, value)

        self._close_pose()

    # --- attributes --------------------------------------------------------

    def _open_attributes(self):
        self._open_element("attributes", {})

    def _close_attributes(self):
        self._close_element("attributes")

    def _add_attributes(self, attributes):
        self._open_attributes()

        for name, value in attributes.items():
            self._add_attribute(name, value)

        self._close_attributes()

    def _open_attribute(self):
        self._open_element("attribute", {})

    def _close_attribute(self):
        self._close_element("attribute")

    def _add_attribute(self, name, value):
        self._open_attribute()

        # <name> and <value> share a single line.
        self._indent(newline=True)
        self._xmlgen.startElement("name", {})
        # str() keeps non-string attribute names from failing in characters()
        self._xmlgen.characters(str(name))
        self._xmlgen.endElement("name")

        self._xmlgen.startElement("value", {})
        self._xmlgen.characters(str(value))
        self._xmlgen.endElement("value")

        self._close_attribute()

    # --- entry point -------------------------------------------------------

    def write(self):
        """Write the complete document for the tracklets given at construction."""
        self._add_headers()

        self._open_serialization()
        self._open_tracklets(self._tracklets)

        for tracklet in self._tracklets:
            self._add_tracklet(tracklet)

        self._close_tracklets()
        self._close_serialization()
[docs]
class KittiRawExporter(Exporter):
    """Export 3D cuboid tracks and point cloud media into ``save_dir``.

    The export writes the tracklet annotation file
    (``KittiRawPath.ANNO_FILE``), the frame id -> item id mapping
    (``KittiRawPath.NAME_MAPPING_FILE``) and, when media saving is enabled,
    the point clouds together with their extra images.
    """

    DEFAULT_IMAGE_EXT = ".jpg"

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base exporter parser with ``--reindex`` and ``--allow-attrs``."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument(
            "--reindex",
            action="store_true",
            help="Assign new indices to frames and tracks. "
            "Allows annotations without 'track_id' (default: %(default)s)",
        )
        parser.add_argument(
            "--allow-attrs",
            action="store_true",
            help="Allow writing annotation attributes (default: %(default)s)",
        )
        return parser

    def __init__(self, extractor, save_dir, reindex=False, allow_attrs=False, **kwargs):
        """
        Args:
            extractor: Dataset to export, forwarded to the base exporter.
            save_dir: Output directory, forwarded to the base exporter.
            reindex: Ignore the items' ``frame`` attribute and generate ids
                for annotations that have no ``track_id``.
            allow_attrs: Write annotation attributes that are not built-in
                or special ones into the poses.
            **kwargs: Forwarded to the base exporter.
        """
        super().__init__(extractor, save_dir, **kwargs)

        self._reindex = reindex
        self._builtin_attrs = KittiRawPath.BUILTIN_ATTRS | KittiRawPath.SPECIAL_ATTRS
        self._allow_attrs = allow_attrs

    @staticmethod
    def _parse_state(state_enum, raw):
        """Resolve a textual state attribute to a member of ``state_enum``.

        The upper-cased text is looked up by enum value first, which is the
        behavior this exporter always had. If that fails, the same text is
        tried as a member name.

        NOTE(review): the enum definitions are not visible from this module;
        the name fallback matters if their values are not the upper-case
        names - confirm against the format definitions.

        Raises:
            ValueError: The text matches neither a value nor a member name.
        """
        key = raw.upper()
        try:
            return state_enum(key)
        except ValueError:
            member = state_enum.__members__.get(key)
            if member is None:
                raise
            return member

    def _resolve_track_id(self, item, ann, tracks):
        """Return the annotation's integer track id, generating one on reindex.

        Raises:
            DatasetExportError: There is no usable ``track_id`` attribute and
                reindexing is disabled.
        """
        track_id = cast(ann.attributes.get("track_id"), int, None)
        if track_id is None and self._reindex:
            # In this format, track id is not used for anything except
            # annotation grouping. So we only need to pick a definitely
            # unused id. A negative one, for example.
            track_id = -(len(tracks) + 1)

        if track_id is None:
            raise DatasetExportError(
                "Item %s: expected track annotations "
                "having 'track_id' (integer) attribute. "
                "Use --reindex to export single shapes." % item.id
            )
        return track_id

    @staticmethod
    def _check_track_consistency(track, item, ann, label, track_id):
        """Ensure a further annotation of a track keeps its scale and label.

        Raises:
            DatasetExportError: The scale or the label differs from the one
                the track was created with.
        """
        # list() keeps an equal non-list sequence from being reported as a
        # mismatch by the list comparison.
        if [track["h"], track["w"], track["l"]] != list(ann.scale):
            # Tracks have fixed scale in the format
            raise DatasetExportError(
                "Item %s: mismatching track shapes, " "track id %s" % (item.id, track_id)
            )

        if track["objectType"] != label:
            raise DatasetExportError(
                "Item %s: mismatching track labels, "
                "track id %s: %s vs. %s" % (item.id, track_id, track["objectType"], label)
            )

    @staticmethod
    def _fill_track_gap(track, frame_id):
        """Insert poses for the frames skipped since the track's last pose."""
        last_key_pose = track["poses"][-1]
        last_keyframe_id = last_key_pose["frame_id"]
        if frame_id == last_keyframe_id + 1:
            return

        # If there is a skip in track frames, add missing as outside
        last_key_pose["occlusion_kf"] = 1
        for missing_id in range(last_keyframe_id + 1, frame_id):
            pose = deepcopy(last_key_pose)
            # Plain values, exactly like the labeled poses: the writer
            # serializes with str(), which would spell out an enum member.
            pose["occlusion"] = OcclusionStates.OCCLUSION_UNSET.value
            pose["truncation"] = TruncationStates.OUT_IMAGE.value
            pose["frame_id"] = missing_id
            track["poses"].append(pose)

    def _make_pose(self, ann, frame_id):
        """Build the pose mapping of one labeled cuboid annotation."""
        attrs = ann.attributes

        occlusion = OcclusionStates.VISIBLE
        if "occlusion" in attrs:
            occlusion = self._parse_state(OcclusionStates, attrs["occlusion"])
        elif attrs.get("occluded"):
            occlusion = OcclusionStates.PARTLY

        truncation = TruncationStates.IN_IMAGE
        if "truncation" in attrs:
            truncation = self._parse_state(TruncationStates, attrs["truncation"])

        # Key order is the order of the elements in the written file
        pose = {
            "tx": ann.position[0],
            "ty": ann.position[1],
            "tz": ann.position[2],
            "rx": ann.rotation[0],
            "ry": ann.rotation[1],
            "rz": ann.rotation[2],
            "state": PoseStates.LABELED.value,
            "occlusion": occlusion.value,
            "occlusion_kf": int(attrs.get("keyframe", False) is True),
            "truncation": truncation.value,
            "amt_occlusion": -1,
            "amt_border_l": -1,
            "amt_border_r": -1,
            "amt_occlusion_kf": -1,
            "amt_border_kf": -1,
            "frame_id": frame_id,
        }

        if self._allow_attrs:
            extra = {}
            for name, value in attrs.items():
                if name in self._builtin_attrs:
                    continue
                if isinstance(value, bool):
                    value = "true" if value else "false"
                extra[name] = value
            pose["attributes"] = extra

        return pose

    def _create_tracklets(self, subset):
        """Group the cuboid annotations of ``subset`` into tracks.

        Every item is passed through ``_write_item`` first, so media is saved
        as a side effect, and the frame id -> item id mapping is written at
        the end.

        Returns:
            The track mappings ordered by track id.

        Raises:
            DatasetExportError: A frame id repeats, an annotation has no
                usable track id, or a track changes its scale or label.
        """
        tracks = {}  # track_id -> track
        name_mapping = {}  # frame_id -> name

        for frame_id, item in enumerate(subset):
            frame_id = self._write_item(item, frame_id)

            if frame_id in name_mapping:
                raise DatasetExportError(
                    "Item %s: frame id %s is repeated in the dataset" % (item.id, frame_id)
                )
            name_mapping[frame_id] = item.id

            for ann in item.annotations:
                if ann.type != AnnotationType.cuboid_3d:
                    continue

                if ann.label is None:
                    log.warning(
                        "Item %s: skipping a %s%s with no label",
                        item.id,
                        ann.type.name,
                        " (#%s)" % ann.id if ann.id is not None else "",
                    )
                    continue

                label = self._get_label(ann.label).name
                track_id = self._resolve_track_id(item, ann, tracks)

                track = tracks.get(track_id)
                if track is None:
                    track = {
                        "objectType": label,
                        "h": ann.scale[0],
                        "w": ann.scale[1],
                        "l": ann.scale[2],
                        "first_frame": frame_id,
                        "poses": [],
                        "finished": 1,  # keep last
                    }
                    tracks[track_id] = track
                else:
                    self._check_track_consistency(track, item, ann, label, track_id)
                    self._fill_track_gap(track, frame_id)

                track["poses"].append(self._make_pose(ann, frame_id))

        self._write_name_mapping(name_mapping)

        return [tracks[track_id] for track_id in sorted(tracks)]

    def _write_name_mapping(self, name_mapping):
        """Write one ``<frame id> <item id>`` line per entry of ``name_mapping``."""
        with open(
            osp.join(self._save_dir, KittiRawPath.NAME_MAPPING_FILE), "w", encoding="utf-8"
        ) as f:
            f.writelines("%s %s\n" % (frame_id, name) for frame_id, name in name_mapping.items())

    def _get_label(self, label_id):
        """Return the label category entry for ``label_id`` ("" for ``None``)."""
        if label_id is None:
            return ""
        label_cat = self._extractor.categories().get(AnnotationType.label, LabelCategories())
        return label_cat.items[label_id]

    def _write_item(self, item, index):
        """Save the item's media if requested and return its frame id.

        Unless reindexing is enabled, the item's integer ``frame`` attribute
        takes precedence over the ``index`` given by the caller.
        """
        if not self._reindex:
            index = cast(item.attributes.get("frame"), int, index)

        if self._save_media and item.media:
            self._save_point_cloud(item, subdir=KittiRawPath.PCD_DIR)

            # The position in the path-sorted list selects the image directory
            images = sorted(item.media.extra_images, key=lambda img: getattr(img, "path", ""))
            for i, image in enumerate(images):
                if image.has_data:
                    image.save(
                        osp.join(
                            self._save_dir,
                            KittiRawPath.IMG_DIR_PREFIX + ("%02d" % i),
                            "data",
                            item.id + self._find_image_ext(image),
                        )
                    )
        elif self._save_media and not item.media:
            log.debug("Item '%s' has no image info", item.id)

        return index

    def _apply_impl(self):
        """Run the export.

        Raises:
            MediaTypeError: The dataset declares a media type other than
                ``PointCloud``.
        """
        if self._extractor.media_type() and self._extractor.media_type() is not PointCloud:
            raise MediaTypeError("Media type is not a point cloud")

        os.makedirs(self._save_dir, exist_ok=True)

        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        if 1 < len(self._extractor.subsets()):
            log.warning(
                "Kitti RAW format supports only a single "
                "subset. Subset information will be ignored on export."
            )

        tracklets = self._create_tracklets(self._extractor)
        with open(osp.join(self._save_dir, KittiRawPath.ANNO_FILE), "w", encoding="utf-8") as f:
            writer = _XmlAnnotationWriter(f, tracklets)
            writer.write()

    @classmethod
    def patch(cls, dataset, patch, save_dir, **kwargs):
        """Re-export the patched dataset and delete media files that became stale.

        After the export, the point cloud and the images of every updated
        item that was removed or has no media are deleted from ``save_dir``.
        """
        conv = cls(patch.as_dataset(dataset), save_dir=save_dir, **kwargs)
        conv.apply()

        pcd_dir = osp.abspath(osp.join(save_dir, KittiRawPath.PCD_DIR))
        for (item_id, subset), status in patch.updated_items.items():
            if status != ItemStatus.removed:
                item = patch.data.get(item_id, subset)
            else:
                item = DatasetItem(item_id, subset=subset)

            # Only removed items and items without media leave files behind
            if not (status == ItemStatus.removed or not item.media):
                continue

            pcd_path = osp.join(pcd_dir, conv._make_pcd_filename(item))
            if osp.isfile(pcd_path):
                os.unlink(pcd_path)

            item_name = osp.basename(item.id)
            for d in os.listdir(save_dir):
                if not d.startswith(KittiRawPath.IMG_DIR_PREFIX):
                    continue

                image_dir = osp.join(save_dir, d, "data", osp.dirname(item.id))
                if not osp.isdir(image_dir):
                    continue

                for p in find_images(image_dir):
                    if osp.splitext(osp.basename(p))[0] == item_name:
                        os.unlink(p)