Source code for datumaro.components.visualizer

# Copyright (C) 2024 Intel Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations

import logging as log
import math
import random
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union, overload

import cv2
import numpy as np
from PIL import ImageColor

from datumaro.components.annotation import (
    Annotation,
    AnnotationType,
    Bbox,
    Caption,
    Cuboid3d,
    DepthAnnotation,
    Ellipse,
    Label,
    LabelCategories,
    Mask,
    Points,
    Polygon,
    PolyLine,
    SuperResolutionAnnotation,
)
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.media import Image

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from matplotlib.text import Text

CAPTION_BBOX_PAD = 0.2
DEFAULT_COLOR_CYCLES: List[str] = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]


def _infer_grid_size(length: int, grid_size: Tuple[Optional[int], Optional[int]]):
    nrows, ncols = grid_size

    if nrows is None and ncols is None:
        nrows = ncols = int(math.sqrt(length))

        while nrows * ncols < length:
            nrows += 1
    elif nrows is None and ncols > 0:
        nrows = int(length / ncols)

        while nrows * ncols < length:
            nrows += 1
    elif nrows > 0 and ncols is None:
        ncols = int(length / nrows)

        while nrows * ncols < length:
            ncols += 1

    assert nrows > 0, "nrows should be a positive integer."
    assert ncols > 0, "ncols should be a positive integer."
    assert length <= nrows * ncols, "The number of ids should less then or equal to nrows * ncols."

    return nrows, ncols


[docs] class Visualizer: def __init__( self, dataset: IDataset, ignored_types: Optional[Iterable[AnnotationType]] = None, figsize: Tuple[float, float] = (8, 6), color_cycles: Optional[List[str]] = None, bbox_linewidth: float = 1.0, text_y_offset: float = 1.5, alpha: float = 1.0, show_plot_title: bool = True, ) -> None: """ Visualizer for Datumaro annotations Parameters ---------- dataset: Datumaro dataset to visualize its items. ignored_types: Categories of labels. It is used to extract label name by label id. figsize: Pyplot Figure instance used to draw annotation. color_cycles: Color cycle corresponding to each label ID. If the length of the color cycle is less than the label ID, then the label ID exceeding the color cycle length is assgined by the following rule. color = color_cycles[label_id % len(color_cycles)] bbox_linewidth: Line width for Bbox, Polygon and PolyLine annotation text_y_offset: Offset of y axis for texts. The higher value puts the text in the upper place of the annotation. alpha: Transparency value when drawing annotations. It should be in [0, 1]. If alpha=0, we do not draw any annotations. show_plot_title: If True, show the plot title formatted as "ID: {item_id}, Subset: {subset}". Otherwise, hide the plot title. """ self.dataset = dataset self.figsize = figsize self.ignored_types = set(ignored_types) if ignored_types is not None else set() self.color_cycles = color_cycles if color_cycles is not None else DEFAULT_COLOR_CYCLES self.bbox_linewidth = bbox_linewidth self.text_y_offset = text_y_offset assert 0.0 <= alpha <= 1.0, "alpha should be in [0, 1]." self.alpha = alpha self.show_plot_title = show_plot_title self._items = [item for item in self.dataset] @property def draw_only_image(self): """ If self.alpha = 0, we do not overdraw any annotation over the image. """ return self.alpha == 0.0 def _draw( self, ann: Annotation, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: """ Draw annotation according to it's annotation type. Parameters ---------- ann: Annotation entity to draw. label_categories: Categories of labels. It is used to extract label name by label id. fig: Pyplot Figure instance used to draw annotation. ax: Pyplot Axes instance used to draw annotation. context: It includes previously drawing history for each annotation type. Currently, it is necessary to avoid drawing again over an already drawn place. For example, multi label dataset has multiple Labels in a DatasetItem. If we don't keep context, it has to draw Label at the same place again and again. """ if ann.type == AnnotationType.label: return self._draw_label(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.mask: return self._draw_mask(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.points: return self._draw_points(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.polygon: return self._draw_polygon(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.polyline: return self._draw_polygon(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.bbox: return self._draw_bbox(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.caption: return self._draw_caption(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.cuboid_3d: return self._draw_cuboid_3d(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.super_resolution_annotation: return self._draw_super_resolution_annotation(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.depth_annotation: return self._draw_depth_annotation(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.ellipse: return self._draw_ellipse(ann, label_categories, fig, ax, context) if ann.type == AnnotationType.hash_key: return None # warning instead of raising an error for unsupported annotation types. log.warning(f"Ignore unknown ann.type={ann.type}") def _get_color(self, ann: Annotation) -> str: color = self.color_cycles[ann.label % len(self.color_cycles)] return color def _sort_by_z_order(self, annotations: List[Annotation]) -> List[Annotation]: def _sort_key(ann: Annotation): z_order = getattr(ann, "z_order", -1) return z_order return sorted(annotations, key=_sort_key)
[docs] def get_random_items(self, n_samples: int) -> List[DatasetItem]: """Get random samples from the dataset""" if n_samples >= len(self.dataset): raise ValueError( f"n_samples={n_samples} should be less than the dataset size ({len(self.dataset)})." ) # Disable B311: random - used for general random sampling not for security/crypto return random.choices(self._items, k=n_samples) # nosec B311
@overload def vis_gallery( self, ids: List[str], ann_ids: List[int], subsets: Union[str, List[str]], *, grid_size: Tuple[Optional[int], Optional[int]] = (None, None), ): """Visualize several :class:`DatasetItem` as a gallery. If `ann_ids` is given, only draw one annotation matching `ann_id` per item. For example, if `ids = ["item_0", "item_1"]` and `ann_ids = [2, 3]`, The item with `item.id` = "item_0" will only draw the annotation with `ann.id` = 2 and will not draw the others. Parameters ---------- ids A list of :class:`DatasetItem`'s ID to visualize ann_ids A list of :class:`Annotation`'s ID to visualize subsets A list of :class:`DatasetItem`'s subset name to visualize. If a string is given, it is automatically expanded into a list up to the length of `ids`. grid_size Grid size of the gallery plot. If `None`, we automatically infer its size. Return ------ :class:`Figure` include visualization plots. """ ... @overload def vis_gallery( self, ids: List[str], subsets: Union[str, List[str]], *, grid_size: Tuple[Optional[int], Optional[int]] = (None, None), ): """Visualize several :class:`DatasetItem` as a gallery Parameters ---------- ids A list of :class:`DatasetItem`'s ID to visualize subsets A list of :class:`DatasetItem`'s subset name to visualize. If a string is given, it is automatically expanded into a list up to the length of `ids`. grid_size Grid size of the gallery plot. If `None`, we automatically infer its size. Return ------ :class:`Figure` include visualization plots. """ ... @overload def vis_gallery( self, items: List[DatasetItem], *, grid_size: Tuple[Optional[int], Optional[int]] = (None, None), ): """Visualize several :class:`DatasetItem` as a gallery Parameters ---------- items A list of :class:`DatasetItem` to visualize Return ------ :class:`Figure` include visualization plots. """ ... @overload def vis_one_sample( self, item_id: str, subset: str, *, ann_id: Optional[int] = None, ax: Optional[Axes] = None, ) -> Figure: """Visualize one :class:`DatasetItem` Parameters ---------- item_id ID of :class:`DatasetItem` to visualize subset Subset name of :class:`DatasetItem` to visualize ann_id If not `None`, only draw an annotation which has `id=ann_id`. ax If not `None`, draw on `ax` instead of creating a new one Return ------ :class:`Figure` include visualization plot of the :class:`DatasetItem`. """ ... @overload def vis_one_sample( self, item: DatasetItem, *, ann_id: Optional[int] = None, ax: Optional[Axes] = None, ) -> Figure: """Visualize one :class:`DatasetItem` Parameters ---------- item :class:`DatasetItem` to visualize ann_id If not `None`, only draw an annotation which has `id=ann_id`. ax If not `None`, draw on `ax` instead of creating a new one Return ------ :class:`Figure` include visualization plot of the :class:`DatasetItem`. """ ...
[docs] def vis_one_sample( self, *inputs, ann_id: Optional[int] = None, ax: Optional[Axes] = None, ) -> Figure: """Visualize one dataset item""" import matplotlib.pyplot as plt if len(inputs) == 1: item_id, subset = None, None (item,) = inputs elif len(inputs) == 2: item_id, subset = inputs item = None if ax is None: fig = plt.figure(figsize=self.figsize) ax = plt.gca() else: fig = ax.get_figure() plt.sca(ax) def _parse_inputs( item_id: Optional[str], subset: Optional[str], item: Optional[DatasetItem] ) -> Tuple[str, str]: if item_id is not None and subset is not None: return item_id, subset elif item is not None: item_id = item.id subset = item.subset return item_id, subset raise ValueError( f"item_id={item_id}, subset={subset}, and item={item} is an invalid input." ) item_id, subset = _parse_inputs(item_id, subset, item) item: DatasetItem = self.dataset.get(item_id, subset) assert item is not None, f"Cannot find id={item_id}, subset={subset}" assert ( item is not Image ), f"Media type should be Image, Current media type={type(item.media)}" img = item.media.data.astype(np.uint8) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) ax.imshow(img) if self.show_plot_title: self._plot_title(ax, item_id, subset) ax.set_axis_off() if self.draw_only_image: return fig annotations = self._sort_by_z_order(item.annotations) categories = self.dataset.categories() label_categories = ( self.dataset.categories()[AnnotationType.label] if AnnotationType.label in categories else None ) context = defaultdict(list) for ann in annotations: if ann.type in self.ignored_types: ignore_type = AnnotationType(ann.type).name msg = f"{ignore_type} in self.ignored_types. Skip it." warnings.warn(msg) continue if ann_id is None or ann_id == ann.id: self._draw(ann, label_categories, fig, ax, context[ann.type]) return fig
def _plot_title(self, ax: Axes, item_id: str, subset: str) -> None: width = ax.transAxes.transform_point((1, 0))[0] - ax.transAxes.transform_point((0, 0))[0] text = ax.set_title(f"ID: {item_id}, Subset: {subset}", loc="center", wrap=True) text.__get_wrapped_text = text._get_wrapped_text def _get_wrapped_text(): wrapped_text = text.__get_wrapped_text() text._text = wrapped_text return wrapped_text text._get_wrapped_text = _get_wrapped_text text._get_wrap_line_width = lambda: width def _draw_label( self, ann: Label, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: label_text = label_categories[ann.label].name if label_categories is not None else ann.label color = self._get_color(ann) if len(context) == 0: x, y = 0.01, 0.99 else: # Draw below the previously drawn label. text: Text = context[-1] # We can know the position of text bbox only if drawing it actually. # https://stackoverflow.com/a/41271773/16880031 fig.canvas.draw() bbox = text.get_window_extent() bbox = ax.transAxes.inverted().transform_bbox(bbox) x, y = 0.01, bbox.y0 text = ax.text(x, y, label_text, ha="left", va="top", color=color, transform=ax.transAxes) context.append(text) def _draw_mask( self, ann: Mask, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: h, w = ann.image.shape source = ann.image if source.dtype != bool: warnings.warn( f"Mask should has dtype == bool, but its dtype == {source.dtype}. " "Try to change it to bool dtype." ) source = source.astype(bool) mask_map = np.zeros((h, w, 4), dtype=np.uint8) color = self._get_color(ann) rgba_color = (*ImageColor.getcolor(color, "RGB"), 0.0) mask_map[source] = rgba_color mask_map[source, 3] = int(255 * self.alpha) ax.imshow(mask_map) def _draw_points( self, ann: Points, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: label_text = label_categories[ann.label].name if label_categories is not None else ann.label color = self._get_color(ann) points = np.array(ann.points) n_points = len(points) // 2 points = points.reshape(n_points, 2) visible = [viz == Points.Visibility.visible for viz in ann.visibility] points = points[visible] ax.scatter(points[:, 0], points[:, 1], color=color) if len(points) > 0: x, y, _, _ = ann.get_bbox() ax.text(x, y - self.text_y_offset, label_text, color=color) def _draw_polygon( self, ann: Union[Polygon, PolyLine], label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: import matplotlib.patches as patches label_text = label_categories[ann.label].name if label_categories is not None else ann.label color = self._get_color(ann) points = np.array(ann.points) n_points = len(points) // 2 points = points.reshape(n_points, 2) polyline = patches.Polygon( points, fill=False, linewidth=self.bbox_linewidth, edgecolor=color, ) ax.add_patch(polyline) if isinstance(ann, Polygon): polygon = patches.Polygon( points, fill=True, facecolor=color if isinstance(ann, Polygon) else "none", alpha=self.alpha, ) ax.add_patch(polygon) x, y, _, _ = ann.get_bbox() ax.text(x, y - self.text_y_offset, label_text, color=color) def _draw_bbox( self, ann: Bbox, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: import matplotlib.patches as patches label_text = label_categories[ann.label].name if label_categories is not None else ann.label color = self._get_color(ann) rect = patches.Rectangle( (ann.x, ann.y), ann.w, ann.h, linewidth=self.bbox_linewidth, edgecolor=color, facecolor="none", ) ax.text(ann.x, ann.y - self.text_y_offset, label_text, color=color) ax.add_patch(rect) def _draw_caption( self, ann: Caption, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: if len(context) == 0: x, y = 0.5, 0.01 else: # Draw on the top of the previously drawn caption. text: Text = context[-1] bbox = text.get_bbox_patch() # We can know the position of text bbox only if drawing it actually. # https://stackoverflow.com/a/41271773/16880031 fig.canvas.draw() drawed_bbox = bbox.get_bbox() x, y = bbox.get_transform().transform( [0, (1.0 + 2 * CAPTION_BBOX_PAD) * drawed_bbox.height] ) x, y = ax.transAxes.inverted().transform([x, y]) x, y = 0.5, y width = ax.transAxes.transform_point((1, 0))[0] - ax.transAxes.transform_point((0, 0))[0] text = ax.text( x, y, ann.caption, ha="center", va="bottom", wrap=True, transform=ax.transAxes, bbox={"facecolor": "white", "alpha": self.alpha}, ) text.get_bbox_patch().set_boxstyle(f"Round,pad={CAPTION_BBOX_PAD}") text._get_wrap_line_width = lambda: width context.append(text) def _draw_cuboid_3d( self, ann: Cuboid3d, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: raise NotImplementedError(f"{ann.type} is not implemented yet.") def _draw_super_resolution_annotation( self, ann: SuperResolutionAnnotation, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: assert ( len(context) == 0 ), "It cannot visualize more than one SuperResolutionAnnotation per item." warnings.warn( "SuperResolutionAnnotation overdraws the high-resolution image over the original image. " "If you want to see the original image, set alpha=0." ) hi_res_img = ann.image.data im = ax.imshow(hi_res_img, alpha=self.alpha) context.append(im) def _draw_depth_annotation( self, ann: DepthAnnotation, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: from mpl_toolkits.axes_grid1 import make_axes_locatable assert len(context) == 0, "It cannot visualize more than one DepthAnnotation per item." depth = ann.image.data im = ax.imshow(depth, alpha=self.alpha) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) fig.colorbar(im, cax) context.append(im) def _draw_ellipse( self, ann: Ellipse, label_categories: Optional[LabelCategories], fig: Figure, ax: Axes, context: List, ) -> None: import matplotlib.patches as patches label_text = label_categories[ann.label].name if label_categories is not None else ann.label color = self._get_color(ann) ellipse = patches.Ellipse( xy=(ann.c_x, ann.c_y), width=ann.w, height=ann.h, linewidth=self.bbox_linewidth, edgecolor=color, facecolor="none", ) ax.text(ann.x1, ann.y1 - self.text_y_offset, label_text, color=color) ax.add_patch(ellipse)