""".This module contains the mapper for label related entities."""
#
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import json
from typing import Dict, cast
from otx.api.entities.color import Color
from otx.api.entities.id import ID
from otx.api.entities.label import Domain, LabelEntity
from otx.api.entities.label_schema import (
LabelGroup,
LabelGroupType,
LabelSchemaEntity,
LabelTree,
)
from .datetime_mapper import DatetimeMapper
from .id_mapper import IDMapper
[docs]
class ColorMapper:
    """Converts a `Color` entity into a plain dictionary and back again."""

    @staticmethod
    def forward(instance: Color) -> dict:
        """Return a dict with the "red", "green", "blue" and "alpha" attributes of ``instance``."""
        channel_names = ("red", "green", "blue", "alpha")
        return {channel: getattr(instance, channel) for channel in channel_names}

    @staticmethod
    def backward(instance: dict) -> Color:
        """Build a `Color` from a dict holding "red", "green", "blue" and "alpha" entries."""
        # Channels are looked up in the same order they are passed to the constructor.
        red = instance["red"]
        green = instance["green"]
        blue = instance["blue"]
        alpha = instance["alpha"]
        return Color(red, green, blue, alpha)
[docs]
class LabelMapper:
    """Converts a `LabelEntity` into a plain dictionary and back again."""

    @staticmethod
    def forward(
        instance: LabelEntity,
    ) -> dict:
        """Return the dictionary form of ``instance``.

        The id, color and creation date are delegated to `IDMapper`, `ColorMapper`
        and `DatetimeMapper`; the domain is stored as ``str(instance.domain)``.
        """
        serialized = {
            "_id": IDMapper().forward(instance.id_),
            "name": instance.name,
        }
        serialized["color"] = ColorMapper().forward(instance.color)
        serialized["hotkey"] = instance.hotkey
        serialized["domain"] = str(instance.domain)
        serialized["creation_date"] = DatetimeMapper.forward(instance.creation_date)
        serialized["is_empty"] = instance.is_empty
        serialized["is_anomalous"] = instance.is_anomalous
        return serialized

    @staticmethod
    def backward(instance: dict) -> LabelEntity:
        """Build a `LabelEntity` from its dictionary form.

        "hotkey" falls back to an empty string, and "is_empty" / "is_anomalous"
        fall back to ``False`` when absent. The "domain" entry is converted with
        ``str`` and used as a key into `Domain`.
        """
        # Values are evaluated top to bottom: id first, then the domain lookup,
        # then the remaining fields.
        fields = {
            "id": IDMapper().backward(instance["_id"]),
            "domain": Domain[str(instance.get("domain"))],
            "name": instance["name"],
            "color": ColorMapper().backward(instance["color"]),
            "hotkey": instance.get("hotkey", ""),
            "creation_date": DatetimeMapper.backward(instance["creation_date"]),
            "is_empty": instance.get("is_empty", False),
            "is_anomalous": instance.get("is_anomalous", False),
        }
        return LabelEntity(**fields)
[docs]
class LabelGroupMapper:
    """Converts a `LabelGroup` into a plain dictionary and back again."""

    @staticmethod
    def forward(instance: LabelGroup) -> dict:
        """Return the dictionary form of ``instance``.

        Labels are stored by their mapped ids only, and the group type is stored
        by its ``name``.
        """
        group_id = IDMapper().forward(instance.id_)
        group_name = instance.name
        member_ids = []
        for label in instance.labels:
            member_ids.append(IDMapper().forward(label.id_))
        return {
            "_id": group_id,
            "name": group_name,
            "label_ids": member_ids,
            "relation_type": instance.group_type.name,
        }

    @staticmethod
    def backward(instance: dict, all_labels: Dict[ID, LabelEntity]) -> LabelGroup:
        """Build a `LabelGroup` from its dictionary form.

        Every id in "label_ids" is resolved through ``all_labels``; the lookup uses
        subscripting, so an id that is not in ``all_labels`` is not skipped.
        """
        group_id = IDMapper().backward(instance["_id"])
        group_name = instance["name"]
        group_type = LabelGroupType[instance["relation_type"]]
        members = []
        for label_id in instance["label_ids"]:
            members.append(all_labels[IDMapper().backward(label_id)])
        return LabelGroup(id=group_id, name=group_name, group_type=group_type, labels=members)
[docs]
class LabelTreeMapper:
    """Converts a `LabelTree` into a plain dictionary and back again."""

    @staticmethod
    def forward(instance: LabelTree) -> dict:
        """Return the dictionary form of ``instance``.

        Nodes are stored as mapped label ids, and each edge as a pair made of the
        mapped ids of its first two elements.
        """
        serialized = {"type": instance.type, "directed": instance.directed}
        serialized["nodes"] = [IDMapper().forward(node.id_) for node in instance.nodes]

        edge_pairs = []
        for edge in instance.edges:
            # Index access on purpose: only the first two items of an edge are used.
            source_id = IDMapper().forward(edge[0].id_)
            target_id = IDMapper().forward(edge[1].id_)
            edge_pairs.append((source_id, target_id))
        serialized["edges"] = edge_pairs
        return serialized

    @staticmethod
    def backward(instance: dict, all_labels: Dict[ID, LabelEntity]) -> LabelTree:
        """Build a `LabelTree` from its dictionary form.

        Node ids that resolve to no (or a falsy) label in ``all_labels`` are left
        out, and an edge is only added when both of its endpoints resolved.

        Raises:
            ValueError: if ``instance["type"]`` is anything other than "tree".
        """
        instance_type = instance["type"]
        if instance_type != "tree":
            raise ValueError(f"Unsupported type `{instance_type}` for label graph")
        output: LabelTree = LabelTree()

        # Resolve every serialized node id first; the map is keyed by the serialized
        # id so the edge entries can be looked up without converting them again.
        label_map = {}
        for label_id in instance["nodes"]:
            label_map[label_id] = all_labels.get(IDMapper().backward(label_id))

        for label in label_map.values():
            if not label:
                continue
            output.add_node(label)

        for edge in instance["edges"]:
            source = label_map.get(edge[0])
            target = label_map.get(edge[1])
            if source and target:
                output.add_edge(source, target)
        return output
[docs]
class LabelSchemaMapper:
    """Converts a `LabelSchemaEntity` into a plain dictionary and back again."""

    @staticmethod
    def forward(
        instance: LabelSchemaEntity,
    ) -> dict:
        """Return the dictionary form of ``instance``.

        The result has three entries: "label_tree", "label_groups" (from
        ``get_groups(include_empty=True)``) and "all_labels", which maps each
        mapped label id to the serialized label (from ``get_labels(True)``).
        """
        serialized_groups = []
        for group in instance.get_groups(include_empty=True):
            serialized_groups.append(LabelGroupMapper().forward(group))

        serialized_tree = LabelTreeMapper().forward(instance.label_tree)

        serialized_labels = {}
        for label in instance.get_labels(True):
            label_key = IDMapper().forward(label.id_)
            serialized_labels[label_key] = LabelMapper().forward(label)

        return {
            "label_tree": serialized_tree,
            "label_groups": serialized_groups,
            "all_labels": serialized_labels,
        }

    @staticmethod
    def backward(instance: dict) -> LabelSchemaEntity:
        """Build a `LabelSchemaEntity` from its dictionary form.

        Labels are deserialized first, because both the tree and the groups refer
        to them by id.
        """
        all_labels = {}
        for raw_id, raw_label in instance["all_labels"].items():
            label_id = IDMapper().backward(raw_id)
            all_labels[label_id] = LabelMapper().backward(raw_label)

        label_tree = LabelTreeMapper().backward(instance["label_tree"], all_labels)

        label_groups = []
        for raw_group in instance["label_groups"]:
            label_groups.append(LabelGroupMapper().backward(raw_group, all_labels))

        return LabelSchemaEntity(
            label_tree=cast(LabelTree, label_tree),
            label_groups=label_groups,
        )
[docs]
def label_schema_to_bytes(label_schema: LabelSchemaEntity) -> bytes:
    """Serialize ``label_schema`` with `LabelSchemaMapper` and return it as encoded JSON.

    The JSON text is indented by 4 spaces and encoded with ``str.encode()``'s
    default encoding.
    """
    schema_dict = LabelSchemaMapper.forward(label_schema)
    json_text = json.dumps(schema_dict, indent=4)
    return json_text.encode()