"""Base MMDataset for Segmentation Task."""
# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence
import numpy as np
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmseg.datasets.pipelines import Compose
from otx.algorithms.common.utils.data import get_old_new_img_indices
from otx.api.entities.dataset_item import DatasetItemEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.label import LabelEntity
from otx.api.utils.segmentation_utils import mask_from_dataset_item
# pylint: disable=invalid-name, too-many-locals, too-many-instance-attributes, super-init-not-called
def get_annotation_mmseg_format(
dataset_item: DatasetItemEntity,
labels: List[LabelEntity],
use_otx_adapter: bool = True,
) -> dict:
"""Function to convert a OTX annotation to mmsegmentation format.
This is used both in the OTXDataset class defined in this file
as in the custom pipeline element 'LoadAnnotationFromOTXDataset'
:param dataset_item: DatasetItem for which to get annotations
:param labels: List of labels in the project
:return dict: annotation information dict in mmseg format
"""
gt_seg_map = mask_from_dataset_item(dataset_item, labels, use_otx_adapter)
gt_seg_map = gt_seg_map.squeeze(2).astype(np.uint8)
ann_info = dict(gt_semantic_seg=gt_seg_map)
return ann_info
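
# A minimal sketch of how this helper is used (assuming `item` is a
# DatasetItemEntity and `labels` is the project's label list):
#
#     ann = get_annotation_mmseg_format(item, labels)
#     ann["gt_semantic_seg"]  # np.uint8 array of shape (H, W)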
@DATASETS.register_module()
class _OTXSegDataset(CustomDataset, metaclass=ABCMeta):
"""Wrapper that allows using a OTX dataset to train mmsegmentation models.
This wrapper is not based on the filesystem,
but instead loads the items here directly from the OTX Dataset object.
The wrapper overwrites some methods of the CustomDataset class: prepare_train_img, prepare_test_img and prepipeline
Naming of certain attributes might seem a bit peculiar but this is due to the conventions set in CustomDataset. For
instance, CustomDatasets expects the dataset items to be stored in the attribute data_infos, which is why it is
named like that and not dataset_items.
"""
class _DataInfoProxy:
"""This class is intended to be a wrapper to use it in CustomDataset-derived class as `self.data_infos`.
Instead of using list `data_infos` as in CustomDataset, our implementation of dataset OTXDataset
uses this proxy class with overriden __len__ and __getitem__; this proxy class
forwards data access operations to otx_dataset and converts the dataset items to the view
convenient for mmsegmentation.
"""
def __init__(
self,
otx_dataset,
labels=None,
**kwargs, # pylint: disable=unused-argument
):
self.otx_dataset = otx_dataset
self.labels = labels
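            # Map each label ID to its positional index in `labels`; used to
            # translate ignored labels into class indices in __getitem__.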
self.label_idx = {label.id: i for i, label in enumerate(labels)}
def __len__(self):
return len(self.otx_dataset)
def __getitem__(self, index):
"""Prepare a dict 'data_info' that is expected by the mmseg pipeline to handle images and annotations.
:return data_info: dictionary that contains the image and image metadata, as well as the labels of
the objects in the image
"""
dataset = self.otx_dataset
item = dataset[index]
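            # Shift indices by 1 because class index 0 is reserved for the
            # implicit "background" class prepended in OTXSegDataset.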
            ignored_labels = np.array([self.label_idx[label.id] + 1 for label in item.ignored_labels])
data_info = dict(
dataset_item=item,
width=item.width,
height=item.height,
index=index,
ann_info=dict(labels=self.labels),
ignored_labels=ignored_labels,
)
return data_info
def __init__(
self,
otx_dataset: DatasetEntity,
pipeline: Sequence[dict],
classes: Optional[List[str]] = None,
test_mode: bool = False,
use_otx_adapter: bool = True,
):
self.otx_dataset = otx_dataset
self.test_mode = test_mode
self.ignore_index = 255
self.reduce_zero_label = False
self.label_map = None
self.use_otx_adapter = use_otx_adapter
dataset_labels = self.otx_dataset.get_labels(include_empty=False)
self.project_labels = self.filter_labels(dataset_labels, classes)
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes, None)
        # Instead of using a list `data_infos` as in CustomDataset, this implementation
        # uses a proxy class with overridden __len__ and __getitem__; the proxy class
        # forwards data access operations to otx_dataset.
        # Note that a list `data_infos` cannot be used here, since the OTX dataset class has no interface to
        # fetch only the annotation of a data item; the whole data item (including the image) would have to be
        # loaded even if only the aspect ratio of the image were needed. For this reason,
        # this implementation does not use tricks such as skipping images with wrong aspect ratios or
        # small image sizes, since that would require reading the whole dataset during initialization.
self.data_infos = _OTXSegDataset._DataInfoProxy(self.otx_dataset, self.project_labels)
self.pipeline = Compose(pipeline)
@staticmethod
    def filter_labels(all_labels: List[LabelEntity], label_names: Optional[List[str]]) -> List[LabelEntity]:
        """Filter and collect the label entities matching the given label names."""
        if label_names is None:
            return all_labels
        filtered_labels = []
for label_name in label_names:
matches = [label for label in all_labels if label.name == label_name]
if len(matches) == 0:
continue
assert len(matches) == 1
filtered_labels.append(matches[0])
return filtered_labels
def __len__(self):
"""Total number of samples of data."""
return len(self.data_infos)
def pre_pipeline(self, results: Dict[str, Any]):
"""Prepare results dict for pipeline."""
results["seg_fields"] = []
def prepare_train_img(self, idx: int) -> dict:
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys introduced by pipeline.
"""
item = self.data_infos[idx]
self.pre_pipeline(item)
out = self.pipeline(item)
return out
def prepare_test_img(self, idx: int) -> dict:
"""Get testing data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by pipeline.
"""
item = self.data_infos[idx]
self.pre_pipeline(item)
out = self.pipeline(item)
return out
    def get_ann_info(self, idx: int):
        """Get annotation information, used for the evaluation of predictions.

        The CustomDataset class implements the method CustomDataset.evaluate,
        which uses get_ann_info to retrieve annotations.

        :param idx: index of the dataset item for which to get the annotations
        :return ann_info: dict that contains the ground truth segmentation map in mmseg format
        """
dataset_item = self.otx_dataset[idx]
ann_info = get_annotation_mmseg_format(dataset_item, self.project_labels, self.use_otx_adapter)
return ann_info
    def get_gt_seg_maps(self, efficient_test: bool = False):
        """Get ground truth segmentation maps for evaluation.

        The `efficient_test` flag is accepted for interface compatibility with
        CustomDataset but has no effect here.
        """
        gt_seg_maps = []
        for item_id in range(len(self)):
            ann_info = self.get_ann_info(item_id)
            gt_seg_maps.append(ann_info["gt_semantic_seg"])
        return gt_seg_maps
@DATASETS.register_module()
class OTXSegDataset(_OTXSegDataset, metaclass=ABCMeta):
"""Wrapper dataset that allows using a OTX dataset to train models."""
def __init__(self, **kwargs):
test_mode = kwargs.get("test_mode", False)
use_otx_adapter = True
if "dataset" in kwargs:
dataset = kwargs["dataset"]
otx_dataset = dataset.otx_dataset
pipeline = dataset.pipeline
classes = dataset.labels
new_classes = dataset.new_classes
else:
otx_dataset = kwargs["otx_dataset"]
pipeline = kwargs["pipeline"]
classes = kwargs["labels"]
new_classes = kwargs.get("new_classes", [])
        if not test_mode:
self.img_indices = get_old_new_img_indices(classes, new_classes, otx_dataset)
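        # Respect an explicit `use_otx_adapter` flag from the image-loading
        # pipeline step, if one is provided.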
for pipe in pipeline:
if pipe["type"] == "LoadImageFromOTXDataset" and "use_otx_adapter" in pipe:
use_otx_adapter = pipe["use_otx_adapter"]
break
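        # mmseg expects an explicit background class at index 0, so prepend it
        # to the user-provided class names.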
if classes:
classes = [c.name for c in classes]
classes = ["background"] + classes
else:
classes = []
super().__init__(
otx_dataset=otx_dataset,
pipeline=pipeline,
classes=classes,
use_otx_adapter=use_otx_adapter,
)
self.CLASSES = [label.name for label in self.project_labels]
if "background" not in self.CLASSES:
self.CLASSES = ["background"] + self.CLASSES
if self.label_map is None:
self.label_map = {}
for i, c in enumerate(self.CLASSES):
if c not in classes:
self.label_map[i] = -1
else:
self.label_map[i] = classes.index(c)
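
# A minimal usage sketch (illustrative only, not part of the library API):
# how this wrapper is typically referenced from an mmsegmentation-style
# config. The pipeline entries below are assumptions, not a canonical recipe.
#
#     train_dataset_cfg = dict(
#         type="OTXSegDataset",
#         otx_dataset=otx_dataset,  # an otx.api DatasetEntity
#         labels=otx_dataset.get_labels(include_empty=False),
#         pipeline=[
#             dict(type="LoadImageFromOTXDataset", use_otx_adapter=True),
#             dict(type="LoadAnnotationFromOTXDataset"),
#         ],
#     )
#     dataset = DATASETS.build(train_dataset_cfg)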