# Source code for otx.algorithms.common.adapters.nncf.patches

"""NNCFNetwork patch util functions for mmcv models."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from contextlib import contextmanager
from functools import partial

from otx.algorithms.common.adapters.nncf.utils import nncf_trace, no_nncf_trace


@contextmanager
def nncf_trace_context(self, img_metas, nncf_compress_postprocessing=True):
    """A context manager for nncf graph tracing.

    Moves the model to CPU and installs an instance-level ``forward`` that is
    pre-bound for tracing. On exit, the override is removed and the model is
    moved back to the device its first parameter was on. This cleanup also
    runs when the ``with`` body raises.

    Args:
        self: The model being traced. It must expose ``parameters()``,
            ``to()``, ``forward`` and ``forward_dummy``. NOTE(review):
            presumably an mmcv/torch ``nn.Module``; confirm against callers.
        img_metas: Bound to ``forward`` as the ``img_metas`` keyword when
            ``nncf_compress_postprocessing`` is True. Unused otherwise.
        nncf_compress_postprocessing (bool): If True, trace ``forward`` with
            ``img_metas`` and ``return_loss=False``. If False, trace
            ``forward_dummy`` instead.

    Yields:
        None
    """
    # onnx_export in mmdet head has a bug on GPU
    # it must be on CPU
    device_backup = next(self.parameters()).device  # pylint: disable=stop-iteration-return
    self = self.to("cpu")
    try:
        if nncf_compress_postprocessing:
            self.forward = partial(self.forward, img_metas=img_metas, return_loss=False)
        else:
            self.forward = partial(self.forward_dummy)
        yield
    finally:
        # make everything normal: drop the instance-level override so the
        # class-level ``forward`` is visible again, even if tracing failed.
        # The default avoids a KeyError if the override was never installed.
        self.__dict__.pop("forward", None)
        self.to(device_backup)


def no_nncf_trace_wrapper(self, fn, *args, **kwargs):  # pylint: disable=unused-argument,invalid-name
    """A wrapper function not to trace in NNCF.

    Calls ``fn(*args, **kwargs)`` while the ``no_nncf_trace()`` context is
    active and returns whatever ``fn`` produced.

    Args:
        self: Unused. NOTE(review): presumably kept so this function can be
            bound or patched in as a method; confirm at the patch site.
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        The value returned by ``fn``.
    """
    outcome = None
    tracing_disabled = no_nncf_trace()
    with tracing_disabled:
        outcome = fn(*args, **kwargs)
    return outcome


def nncf_trace_wrapper(self, fn, *args, **kwargs):  # pylint: disable=unused-argument,invalid-name
    """A wrapper function to trace in NNCF.

    Calls ``fn(*args, **kwargs)`` while the ``nncf_trace()`` context is active
    and returns whatever ``fn`` produced.

    Args:
        self: Unused. NOTE(review): presumably kept so this function can be
            bound or patched in as a method; confirm at the patch site.
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        The value returned by ``fn``.
    """
    outcome = None
    tracing_enabled = nncf_trace()
    with tracing_enabled:
        outcome = fn(*args, **kwargs)
    return outcome