Source code for otx.api.entities.shapes.shape
"""This file defines the ShapeEntity interface and the Shape abstract class."""
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import abc
import datetime
import warnings
from enum import IntEnum, auto
from typing import TYPE_CHECKING
from shapely.errors import PredicateError, TopologicalError
from shapely.geometry import Polygon as shapely_polygon
if TYPE_CHECKING:
from otx.api.entities.shapes.rectangle import Rectangle
[docs]
class GeometryException(ValueError):
    """Raised when the geometry of a Shape turns out to be invalid.

    Shape uses it, for example, when the intersection between two shapes cannot
    be computed. Because it derives from ``ValueError``, callers may also catch it
    as a ``ValueError``.
    """
[docs]
class ShapeType(IntEnum):
    """Enumerates the kinds of geometric Shape that can be used.

    Values are spelled out explicitly; they match what ``auto()`` assigns for an
    ``IntEnum`` (consecutive integers starting at 1, in declaration order).
    """

    ELLIPSE = 1
    RECTANGLE = 2
    POLYGON = 3
[docs]
class ShapeEntity(metaclass=abc.ABCMeta):
    """Interface for the shapes that annotate media, whether given by a user or produced by system analysis.

    Every shape is a two-dimensional geometric object living in a normalized
    coordinate system, i.e. its coordinate values are expected to range from 0 to 1.
    """

    # pylint: disable=redefined-builtin
    def __init__(self, shape_type: ShapeType):
        """Remember which kind of shape this instance is.

        Args:
            shape_type (ShapeType): Kind of shape this instance represents.
        """
        # Kept private; exposed through the setter-less ``type`` property below.
        self._type = shape_type

    @property
    def type(self) -> ShapeType:
        """ShapeType: Kind of shape this instance represents."""
        return self._type

    @abc.abstractmethod
    def get_area(self) -> float:
        """Compute the area of the shape.

        Returns:
            float: Area of the shape.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def intersects(self, other: "Shape") -> bool:
        """Tell whether ``other`` intersects this shape.

        Args:
            other (Shape): Shape to test against.

        Returns:
            bool: ``True`` when the two shapes intersect, ``False`` otherwise.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def contains_center(self, other: "ShapeEntity") -> bool:
        """Tell whether the center of ``other`` lies inside this shape.

        Args:
            other (ShapeEntity): Shape whose center is tested.

        Returns:
            bool: ``True`` when the center of ``other`` lies inside this shape, ``False`` otherwise.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def normalize_wrt_roi_shape(self, roi_shape: "Rectangle") -> "Shape":
        """Map the shape from the ``roi`` coordinate system into the normalized coordinate system.

        This is the inverse of ``denormalize_wrt_roi_shape``. Tasks use it when they
        want to save their analysis results. In a Detection -> Segmentation pipeline,
        for example, the segmentation results have to be normalized with respect to
        the roi (the bounding boxes) coming from the detection step.

        Args:
            roi_shape (Rectangle): Shape of the roi.

        Returns:
            Shape: The shape expressed in the normalized coordinate system.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def denormalize_wrt_roi_shape(self, roi_shape: "Rectangle") -> "ShapeEntity":
        """Map the shape from the normalized coordinate system into the ``roi`` coordinate system.

        This is the inverse of ``normalize_wrt_roi_shape``. It is used to pull ground
        truth while tasks are being trained. Concrete Shape implementations give
        examples.

        Args:
            roi_shape (Rectangle): Shape of the roi.

        Returns:
            ShapeEntity: The shape expressed in the ``roi`` coordinate system.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _as_shapely_polygon(self) -> shapely_polygon:
        """Convert the shape into a shapely polygon.

        Within the SDK, shapely polygons are used to compute intersections between
        Shapes and to find the shapes that are visible inside a given ROI.

        Returns:
            shapely_polygon: Shapely polygon equivalent of the shape.
        """
        raise NotImplementedError
[docs]
class Shape(ShapeEntity):
    """Base class for Shape entities.

    Adds a modification timestamp on top of ``ShapeEntity`` and implements
    ``intersects`` and ``contains_center`` by comparing the shapely polygons
    returned by ``_as_shapely_polygon``. Subclasses still have to provide the
    remaining abstract methods of ``ShapeEntity`` (``get_area`` is overridden
    here only to raise ``NotImplementedError``).
    """

    # pylint: disable=redefined-builtin, too-many-arguments; Requires refactor
    def __init__(self, shape_type: ShapeType, modification_date: datetime.datetime):
        """Initialize the shape.

        Args:
            shape_type (ShapeType): Kind of shape this instance represents.
            modification_date (datetime.datetime): Date of the last modification of the shape.
        """
        super().__init__(shape_type=shape_type)
        self.modification_date = modification_date

    def __repr__(self):
        """Return a string representation of the shape that includes its last modification date."""
        return f"Shape with modification date:('{self.modification_date}')"

    def get_area(self) -> float:
        """Get the area of the shape.

        Raises:
            NotImplementedError: Always; concrete shapes are expected to override this method.
        """
        raise NotImplementedError

    # pylint: disable=protected-access
    def intersects(self, other: "Shape") -> bool:
        """Returns True, if other intersects with shape, otherwise returns False.

        Args:
            other (Shape): Shape to compare with.

        Returns:
            bool: ``True`` if the shapely polygons of both shapes intersect, ``False`` otherwise.

        Raises:
            GeometryException: If shapely raises ``PredicateError`` or ``TopologicalError``
                while evaluating the intersection (the original error is chained as the cause).
        """
        polygon_roi = self._as_shapely_polygon()
        polygon_shape = other._as_shapely_polygon()
        # Only the predicate itself is guarded; conversion errors propagate unchanged.
        try:
            return polygon_roi.intersects(polygon_shape)
        except (PredicateError, TopologicalError) as exception:
            raise GeometryException(
                f"The intersection between the shapes {self} and {other} could not be computed: {exception}."
            ) from exception

    def contains_center(self, other: "ShapeEntity") -> bool:
        """Checks whether the center of the 'other' shape is located in the shape.

        The center is the centroid of ``other``'s shapely polygon. Per shapely's
        ``contains`` semantics, a centroid lying exactly on this shape's boundary is
        not considered contained.

        Args:
            other (ShapeEntity): Shape to compare with.

        Returns:
            bool: Boolean that indicates whether the center of the other shape is located in the shape.

        Note:
            Unlike ``intersects``, errors raised by shapely here are not wrapped in
            ``GeometryException``; they propagate unchanged.
        """
        polygon_roi = self._as_shapely_polygon()
        polygon_shape = other._as_shapely_polygon()
        return polygon_roi.contains(polygon_shape.centroid)

    def _validate_coordinates(self, x: float, y: float) -> bool:
        """Check if coordinate is valid.

        Checks whether the values of a given x, y coordinate pair lie within the closed
        range [0, 1] expected for the normalized coordinate system. Issues a
        ``UserWarning`` if the coordinates are out of bounds.

        Args:
            x (float): x-coordinate to validate
            y (float): y-coordinate to validate

        Returns:
            bool: ``True`` if coordinates are within expected range, ``False`` otherwise
        """
        # Bounds are inclusive: 0.0 and 1.0 are both valid normalized coordinates.
        if not ((0.0 <= x <= 1.0) and (0.0 <= y <= 1.0)):
            warnings.warn(
                f"{type(self).__name__} coordinates (x={x}, y={y}) are out of bounds, a normalized "
                f"coordinate system is assumed. All coordinates are expected to be in range (0,1).",
                UserWarning,
            )
            return False
        return True

    def __hash__(self):
        """Return a hash derived from ``str(self)``.

        Without a ``__str__`` override, ``str(self)`` falls back to ``__repr__``.
        """
        return hash(str(self))