Source code for otx.core.ov.ops.generation
"""Generation-related module for otx.core.ov.ops."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from otx.core.ov.ops.builder import OPS
from otx.core.ov.ops.op import Attribute, Operation
from otx.core.ov.ops.type_conversions import _ov_to_torch
@dataclass
class RangeV4Attribute(Attribute):
    """Attribute container for the ``Range`` (v4) operation.

    Holds the single node attribute, ``output_type``: the OpenVINO element-type
    name that ``RangeV4.forward`` resolves to a torch dtype via ``_ov_to_torch``.
    """

    # OpenVINO element-type string of the generated sequence; used as a key
    # into ``_ov_to_torch`` when the op runs.
    output_type: str
[docs]
@OPS.register()
class RangeV4(Operation[RangeV4Attribute]):
    """Torch implementation of the OpenVINO ``Range`` node (version 4).

    Produces a 1-D tensor holding the half-open sequence ``[start, stop)``
    advanced by ``step``, in the element type named by ``output_type``.
    """

    TYPE = "Range"
    # NOTE(review): Range-4 belongs to OpenVINO opset4, yet this registers
    # under "opset1". This is kept as-is because it may be the project's
    # registry convention; confirm before changing.
    VERSION = "opset1"
    ATTRIBUTE_FACTORY = RangeV4Attribute

    def forward(self, start, stop, step):
        """Generate the range tensor for this node.

        Args:
            start: First value of the sequence. Must expose ``.device``
                (i.e. be a tensor); presumably a scalar tensor -- confirm.
            stop: Exclusive upper (or lower, for negative ``step``) bound.
            step: Increment between consecutive values.

        Returns:
            ``torch.arange(start, stop, step)`` cast to the dtype mapped from
            ``self.attrs.output_type``, placed on ``start``'s device and
            created with ``requires_grad=False``.
        """
        # Resolve the OpenVINO element type first; an unknown type raises
        # KeyError before any tensor work happens.
        out_dtype = _ov_to_torch[self.attrs.output_type]
        # Keep the result on whatever device the inputs live on.
        out_device = start.device
        return torch.arange(
            start,
            stop,
            step,
            dtype=out_dtype,
            device=out_device,
            requires_grad=False,
        )