Source code for otx.api.entities.tensor
"""This module implements the Tensor entity."""
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from typing import Tuple
import numpy as np
from otx.api.entities.metadata import IMetadata
[docs]
class TensorEntity(IMetadata):
"""Represents a metadata of tensor type in OTX.
Args:
name: name of metadata
numpy: the numpy data of the tensor
"""
def __init__(self, name: str, numpy: np.ndarray):
self.name = name
# Copying Numpy array as it points to the same memory address
self._numpy = np.copy(numpy)
@property
def numpy(self) -> np.ndarray:
"""Returns the numpy representation of the tensor."""
return self._numpy
@numpy.setter
def numpy(self, value):
self._numpy = value
@property
def shape(self) -> Tuple[int, ...]:
"""Returns the shape of the tensor."""
return self._numpy.shape
def __eq__(self, other):
"""Returns True if the tensors are equal."""
if isinstance(other, TensorEntity):
return np.array_equal(self.numpy, other.numpy)
return False
def __str__(self):
"""Returns the string representation of the tensor."""
return f"{self.__class__.__name__}(name={self.name}, shape={self.shape})"
def __repr__(self):
"""Returns the representation of the tensor."""
return f"{self.__class__.__name__}(name={self.name})"