# Source code for otx.algo.utils.utils

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions for OTX algo."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import torch

if TYPE_CHECKING:
    from torch import nn


def _torch_hub_model_reduce(self) -> tuple[Callable, tuple]:  # noqa: ANN001
    return (torch_hub_load, self.torch_hub_load_args)


def torch_hub_load(repo_or_dir: str, model: str) -> nn.Module:
    """Load a module from ``torch.hub`` and make it picklable.

    The load arguments are stored on the returned module as ``torch_hub_load_args``.
    The module's class gets ``_torch_hub_model_reduce`` as its ``__reduce__``, so pickling
    the module records only ``(torch_hub_load, (repo_or_dir, model))``. Unpickling
    therefore re-runs this function rather than restoring the in-memory module state.

    Args:
        repo_or_dir: Repository or local directory passed to ``torch.hub.load``.
        model: Name of the entry point passed to ``torch.hub.load``.

    Returns:
        The module returned by ``torch.hub.load``.
    """
    module = torch.hub.load(
        repo_or_dir=repo_or_dir,
        model=model,
    )
    # Arguments needed to rebuild this instance on unpickle; read back by the reducer.
    module.torch_hub_load_args = (repo_or_dir, model)
    # Install the reducer as a plain function so it becomes a normal method and every
    # instance reduces with its *own* ``torch_hub_load_args``. Binding it to ``module``
    # (via ``__get__``) on the class would make every instance of this class pickle as the
    # most recently loaded model and would keep that model alive through the class.
    module.__class__.__reduce__ = _torch_hub_model_reduce
    return module