# Source code for otx.core.exporter.diffusion

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Exporter for diffusion models that uses native torch and OpenVINO conversion tools."""

from __future__ import annotations

import logging as log
from pathlib import Path

import onnx
import openvino
import torch

from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.model.base import OTXModel
from otx.core.types.export import OTXExportFormatType
from otx.core.types.precision import OTXPrecisionType


class DiffusionOTXModelExporter(OTXNativeModelExporter):
    """Export diffusion models to OpenVINO IR or ONNX with torch and OpenVINO tooling."""

    def export(  # type: ignore[override]
        self,
        model: OTXModel,
        output_dir: Path,
        base_model_name: str = "exported_model",
        export_format: OTXExportFormatType = OTXExportFormatType.OPENVINO,
        precision: OTXPrecisionType = OTXPrecisionType.FP32,
        to_exportable_code: bool = False,
    ) -> Path:
        """Export ``model`` to a deployable format and return the artifact path.

        ``OTXExportFormatType.OPENVINO`` is routed to :meth:`to_openvino`; every
        other value of ``export_format`` is routed to :meth:`to_onnx`.

        Args:
            model (OTXModel): Model to be exported.
            output_dir (Path): Directory in which the export artifacts are stored.
            base_model_name (str, optional): File name (without extension) of the
                exported model. Defaults to "exported_model".
            export_format (OTXExportFormatType, optional): Target format of the
                exported model. Defaults to ``OTXExportFormatType.OPENVINO``.
            precision (OTXPrecisionType, optional): Precision of the exported
                model's weights. Defaults to ``OTXPrecisionType.FP32``.
            to_exportable_code (bool, optional): Whether to generate exportable
                code. Not supported here: for OpenVINO export a warning is logged
                and the flag is ignored; for the ONNX path it is not consulted.

        Returns:
            Path: Path to the exported model.
        """
        targets_openvino = export_format == OTXExportFormatType.OPENVINO
        if not targets_openvino:
            return self.to_onnx(model, output_dir, base_model_name, precision)

        if to_exportable_code:
            # The request is not rejected; it is only reported and then dropped.
            log.warning("Exportable code option is not supported and will be ignored.")
        return self.to_openvino(model, output_dir, base_model_name, precision)

    def to_openvino(
        self,
        model: OTXModel | torch.nn.Module,
        output_dir: Path,
        base_model_name: str = "exported_model",
        precision: OTXPrecisionType = OTXPrecisionType.FP32,
    ) -> Path:
        """Convert ``model`` to OpenVINO IR and save it as ``<base_model_name>.xml``.

        The conversion goes through ``openvino.convert_model`` only. The ``"args"``
        entry of ``self.onnx_export_configuration`` serves both as the example
        input and, via the ``.shape`` of each of its values, as the input shapes.

        Args:
            model (OTXModel | torch.nn.Module): Model to be converted.
            output_dir (Path): Directory in which the IR is stored.
            base_model_name (str, optional): File name (without extension) of the
                exported model. Defaults to "exported_model".
            precision (OTXPrecisionType, optional): When equal to
                ``OTXPrecisionType.FP16`` the model is saved with
                ``compress_to_fp16=True``. Defaults to ``OTXPrecisionType.FP32``.

        Returns:
            Path: Path to the saved ``.xml`` file.
        """
        example_inputs = self.onnx_export_configuration["args"]
        input_shapes = {name: tensor.shape for name, tensor in example_inputs.items()}

        ov_model = openvino.convert_model(
            model,
            example_input=example_inputs,
            input=input_shapes,
        )
        ov_model = self._postprocess_openvino_model(ov_model)

        xml_path = output_dir / f"{base_model_name}.xml"
        use_fp16 = precision == OTXPrecisionType.FP16
        openvino.save_model(ov_model, xml_path, compress_to_fp16=use_fp16)
        log.info("Converting to OpenVINO is done.")

        return Path(xml_path)

    def to_onnx(
        self,
        model: OTXModel | torch.nn.Module,
        output_dir: Path,
        base_model_name: str = "exported_model",
        precision: OTXPrecisionType = OTXPrecisionType.FP32,
        embed_metadata: bool = True,
        model_type: str = "stable_diffusion",
    ) -> Path:
        """Export ``model`` to ONNX and save it as ``<base_model_name>.onnx``.

        The model is first written with ``torch.onnx.export`` (configured by
        ``self.onnx_export_configuration``), then reloaded, post-processed,
        optionally enriched with metadata and written back to the same file.

        Args:
            model (OTXModel | torch.nn.Module): Model to be exported.
            output_dir (Path): Directory in which the ONNX model is saved.
            base_model_name (str, optional): File name (without extension) of the
                exported model. Defaults to "exported_model".
            precision (OTXPrecisionType, optional): Precision forwarded to the
                ONNX post-processing step. Defaults to ``OTXPrecisionType.FP32``.
            embed_metadata (bool, optional): Whether to embed metadata in the ONNX
                model; only takes effect when ``self.metadata`` is not None.
                Defaults to True.
            model_type (str, optional): Value stored under the
                ``("model_info", "model_type")`` metadata key when metadata is
                embedded. Defaults to "stable_diffusion".

        Returns:
            Path: Path to the saved ONNX model.
        """
        onnx_path = str(output_dir / f"{base_model_name}.onnx")

        torch.onnx.export(model=model, f=onnx_path, **self.onnx_export_configuration)

        # NOTE(review): the meaning of the positional ``False`` is defined by the
        # base-class helper, which is not visible here.
        onnx_model = self._postprocess_onnx_model(onnx.load(onnx_path), False, precision)

        if self.metadata is not None and embed_metadata:
            metadata = self._extend_model_metadata(self.metadata)
            metadata[("model_info", "model_type")] = model_type
            onnx_model = self._embed_onnx_metadata(onnx_model, metadata)

        # Overwrite the raw torch export with the post-processed model.
        onnx.save(onnx_model, onnx_path)
        log.info("Converting to ONNX is done.")

        return Path(onnx_path)