Source code for otx.algorithms.anomaly.adapters.anomalib.data.data

"""Anomaly Dataset Utils."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Dict, List, Optional, Union

import numpy as np
import torch
from anomalib.data.base.datamodule import collate_fn
from anomalib.data.utils.transform import get_transforms
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.core.datamodule import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model_template import TaskType
from otx.api.entities.shapes.polygon import Polygon
from otx.api.entities.shapes.rectangle import Rectangle
from otx.api.entities.subset import Subset
from otx.api.utils.dataset_utils import (
    contains_anomalous_images,
    split_local_global_dataset,
)
from otx.api.utils.segmentation_utils import mask_from_dataset_item
from otx.utils.logger import get_logger

logger = get_logger()


class OTXAnomalyDataset(Dataset):
    """Anomaly Dataset Adaptor.

    This class converts an OTX Dataset into an Anomalib dataset, which is a
    sub-class of Vision Dataset.

    Args:
        config (Union[DictConfig, ListConfig]): Anomalib config
        dataset (DatasetEntity): OTX SDK Dataset
        task_type (TaskType): Type of the anomaly task (classification, detection or segmentation)

    Example:
        >>> from tests.helpers.dataset import OTXAnomalyDatasetGenerator
        >>> from otx.algorithms.anomaly.adapters.anomalib.data.data import OTXAnomalyDataset

        >>> dataset_generator = OTXAnomalyDatasetGenerator()
        >>> dataset = dataset_generator.generate()
        >>> anomaly_dataset = OTXAnomalyDataset(config=config, dataset=dataset, task_type=task_type)
        >>> anomaly_dataset[0]["image"].shape
        torch.Size([3, 256, 256])
    """

    def __init__(self, config: Union[DictConfig, ListConfig], dataset: DatasetEntity, task_type: TaskType):
        self.config = config
        self.dataset = dataset
        self.task_type = task_type

        # TODO: distinguish between train and val config here
        self.transform = get_transforms(
            config=config.dataset.transform_config.train, image_size=tuple(config.dataset.image_size), to_tensor=True
        )
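        # NOTE: ``get_transforms`` builds an albumentations-style pipeline, so
        # ``self.transform(image=...)`` returns a dict that ``__getitem__`` indexes
        # with ["image"] (and ["mask"] for segmentation).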

    def __len__(self) -> int:
        """Get size of the dataset.

        Returns:
            int: Size of the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, index: int) -> Dict[str, Union[int, Tensor]]:
        """Get dataset item.

        Args:
            index (int): Index of the dataset sample.

        Raises:
            ValueError: When the task type is not supported.

        Returns:
            Dict[str, Union[int, Tensor]]: Dataset item.
        """
        dataset_item = self.dataset[index]
        item: Dict[str, Union[int, Tensor]] = {"index": index}
        if self.task_type == TaskType.ANOMALY_CLASSIFICATION:
            # Detection currently relies on image labels only, meaning it'll use image
            #   threshold to find the predicted bounding boxes.
            item["image"] = self.transform(image=dataset_item.numpy)["image"]
        elif self.task_type == TaskType.ANOMALY_DETECTION:
            item["image"] = self.transform(image=dataset_item.numpy)["image"]
            item["boxes"] = torch.empty((0, 4))
            height, width = self.config.dataset.image_size
            boxes = []
            for annotation in dataset_item.get_annotations():
                if isinstance(annotation.shape, Rectangle) and not Rectangle.is_full_box(annotation.shape):
                    boxes.append(
                        Tensor(
                            [
                                annotation.shape.x1 * width,
                                annotation.shape.y1 * height,
                                annotation.shape.x2 * width,
                                annotation.shape.y2 * height,
                            ]
                        )
                    )
            if boxes:
                item["boxes"] = torch.stack(boxes)
        elif self.task_type == TaskType.ANOMALY_SEGMENTATION:
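            # Build the pixel-level ground-truth mask: derive it from polygon
            # annotations when they exist, otherwise fall back to an all-zero
            # (all-normal) mask with the image's spatial size.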
            if any((isinstance(annotation.shape, Polygon) for annotation in dataset_item.get_annotations())):
                mask = mask_from_dataset_item(dataset_item, dataset_item.get_shapes_labels()).squeeze()
            else:
                mask = np.zeros(dataset_item.numpy.shape[:2]).astype(int)
            pre_processed = self.transform(image=dataset_item.numpy, mask=mask)
            item["image"] = pre_processed["image"]
            item["mask"] = pre_processed["mask"]
        else:
            raise ValueError(f"Unsupported task type: {self.task_type}")

        if len(dataset_item.get_shapes_labels()) > 0 and dataset_item.get_shapes_labels()[0].is_anomalous:
            item["label"] = 1
        else:
            item["label"] = 0
        return item

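# For reference, OTXAnomalyDataset.__getitem__ yields dictionaries of the form
#   classification: {"index", "image", "label"}
#   detection:      {"index", "image", "boxes", "label"}
#   segmentation:   {"index", "image", "mask", "label"}
# as assembled above.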

class OTXAnomalyDataModule(LightningDataModule):
    """Anomaly DataModule.

    This class converts an OTX Dataset into an Anomalib dataset and stores
    train/val/test dataloaders.

    Args:
        config (Union[DictConfig, ListConfig]): Anomalib config
        dataset (DatasetEntity): OTX SDK Dataset
        task_type (TaskType): Type of the anomaly task (classification, detection or segmentation)

    Example:
        >>> from tests.helpers.dataset import OTXAnomalyDatasetGenerator
        >>> from otx.algorithms.anomaly.adapters.anomalib.data.data import OTXAnomalyDataModule

        >>> dataset_generator = OTXAnomalyDatasetGenerator()
        >>> dataset = dataset_generator.generate()
        >>> data_module = OTXAnomalyDataModule(config=config, dataset=dataset, task_type=task_type)
        >>> i, data = next(enumerate(data_module.train_dataloader()))
        >>> data["image"].shape
        torch.Size([32, 3, 256, 256])
    """

    def __init__(self, config: Union[DictConfig, ListConfig], dataset: DatasetEntity, task_type: TaskType) -> None:
        super().__init__()
        self.config = config
        self.dataset = dataset
        self.task_type = task_type

        self.train_otx_dataset: DatasetEntity
        self.val_otx_dataset: DatasetEntity
        self.test_otx_dataset: DatasetEntity
        self.predict_otx_dataset: DatasetEntity

    def setup(self, stage: Optional[str] = None) -> None:
        """Setup Anomaly Data Module.

        Args:
            stage (Optional[str], optional): train/val/test stages. Defaults to None.
        """
        if stage != "predict":
            self.summary()

        if stage == "fit" or stage is None:
            self.train_otx_dataset = self.dataset.get_subset(Subset.TRAINING)
            self.val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION)

        if stage == "validate":
            self.val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION)

        if stage == "test" or stage is None:
            self.test_otx_dataset = self.dataset.get_subset(Subset.TESTING)

        if stage == "predict":
            self.predict_otx_dataset = self.dataset

    def summary(self):
        """Print size of the dataset, number of anomalous images and number of normal images."""
        for subset in [Subset.TRAINING, Subset.VALIDATION, Subset.TESTING]:
            dataset = self.dataset.get_subset(subset)
            num_items = len(dataset)
            num_normal = len([item for item in dataset if not item.get_shapes_labels()[0].is_anomalous])
            num_anomalous = len([item for item in dataset if item.get_shapes_labels()[0].is_anomalous])
            logger.info(
                "'%s' subset size: Total '%d' images. Normal: '%d' images. Anomalous: '%d' images",
                subset,
                num_items,
                num_normal,
                num_anomalous,
            )

    def train_dataloader(
        self,
    ) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
        """Train Dataloader.

        Returns:
            Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: Train dataloader.
        """
        dataset = OTXAnomalyDataset(self.config, self.train_otx_dataset, self.task_type)
        return DataLoader(
            dataset,
            shuffle=False,
            batch_size=self.config.dataset.train_batch_size,
            num_workers=self.config.dataset.num_workers,
            collate_fn=collate_fn,
        )

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Validation Dataloader.

        Returns:
            Union[DataLoader, List[DataLoader]]: Validation Dataloader.
        """
        global_dataset, local_dataset = split_local_global_dataset(self.val_otx_dataset)
        logger.info(f"Global annotations: {len(global_dataset)}")
        logger.info(f"Local annotations: {len(local_dataset)}")
        if contains_anomalous_images(local_dataset):
            logger.info("Dataset contains polygon annotations. Passing masks to anomalib.")
            dataset = OTXAnomalyDataset(self.config, local_dataset, self.task_type)
        else:
            logger.info("Dataset does not contain polygon annotations. Not passing masks to anomalib.")
            dataset = OTXAnomalyDataset(self.config, global_dataset, TaskType.ANOMALY_CLASSIFICATION)
        return DataLoader(
            dataset,
            shuffle=False,
            batch_size=self.config.dataset.eval_batch_size,
            num_workers=self.config.dataset.num_workers,
            collate_fn=collate_fn,
        )

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Test Dataloader.

        Returns:
            Union[DataLoader, List[DataLoader]]: Test Dataloader.
        """
        dataset = OTXAnomalyDataset(self.config, self.test_otx_dataset, self.task_type)
        return DataLoader(
            dataset,
            shuffle=False,
            batch_size=self.config.dataset.test_batch_size,
            num_workers=self.config.dataset.num_workers,
            collate_fn=collate_fn,
        )

    def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Predict Dataloader.

        Returns:
            Union[DataLoader, List[DataLoader]]: Predict Dataloader.
        """
        dataset = OTXAnomalyDataset(self.config, self.predict_otx_dataset, self.task_type)
        return DataLoader(
            dataset,
            shuffle=False,
            batch_size=self.config.dataset.eval_batch_size,
            num_workers=self.config.dataset.num_workers,
            collate_fn=collate_fn,
        )
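

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so the module is unchanged
# when imported). It mirrors the docstring examples above; the dataset
# generator, the config object and its keys (dataset.image_size,
# dataset.train_batch_size, ...) are assumptions, not guaranteed APIs.
#
#     from tests.helpers.dataset import OTXAnomalyDatasetGenerator
#
#     dataset = OTXAnomalyDatasetGenerator().generate()  # OTX DatasetEntity
#     # `config` is an anomalib-style DictConfig providing the dataset.* fields
#     # referenced in this module.
#     data_module = OTXAnomalyDataModule(config, dataset, TaskType.ANOMALY_CLASSIFICATION)
#     data_module.setup(stage="fit")
#     batch = next(iter(data_module.train_dataloader()))
#     batch["image"].shape  # -> torch.Size([train_batch_size, 3, H, W])
# ---------------------------------------------------------------------------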