Source code for otx.algorithms.common.adapters.nncf.utils.utils
"""NNCF utils."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from collections import OrderedDict
from contextlib import contextmanager
from importlib.util import find_spec
import torch
# Probe once, at import time, whether the optional ``nncf`` package can be
# found (without importing it); the helpers below consult this flag.
_is_nncf_enabled = bool(find_spec("nncf"))
[docs]
def is_nncf_enabled():
    """Report whether the optional NNCF package is available.

    Returns:
        bool: The module-level flag computed once at import time from
        ``importlib.util.find_spec("nncf")``.
    """
    enabled = _is_nncf_enabled
    return enabled
[docs]
def check_nncf_is_enabled():
    """Fail fast when NNCF is required but not installed.

    Raises:
        RuntimeError: If :func:`is_nncf_enabled` reports NNCF as unavailable.
    """
    if is_nncf_enabled():
        return
    raise RuntimeError("Tried to use NNCF, but NNCF is not installed")
[docs]
def get_nncf_version():
    """Return ``nncf.__version__``, or ``None`` when NNCF is not installed.

    The ``nncf`` import is deferred to call time so this module stays
    importable without NNCF.
    """
    if is_nncf_enabled():
        import nncf

        return nncf.__version__
    return None
def load_checkpoint(model, filename, map_location=None, strict=False):
    """Load a checkpoint into ``model`` and return any NNCF state it carries.

    Three checkpoint layouts are recognised:

    * an ``OrderedDict`` (e.g. written by ``torch.save(model.state_dict(), ...)``),
      which is loaded into ``model`` as-is;
    * a ``dict`` with a ``"state_dict"`` key whose ``"meta"`` entry contains
      ``"nncf_meta"``: ``nncf_meta.state_to_build`` is loaded into ``model``,
      while ``checkpoint["state_dict"]`` and ``nncf_meta.compression_ctrl`` are
      handed back to the caller;
    * any other ``dict`` with a ``"state_dict"`` key, whose ``"state_dict"``
      entry is loaded into ``model``.

    Args:
        model (Module): Module to load the weights into.
        filename (str | os.PathLike | file-like): Checkpoint location, passed
            straight to :func:`torch.load` (local paths or file objects; URLs
            are not resolved here).
        map_location: Same as the ``map_location`` argument of
            :func:`torch.load`.
        strict (bool): Whether the keys of the loaded state dict must match
            the model's keys exactly; forwarded to
            ``mmcv.runner.load_state_dict``.

    Returns:
        tuple: ``(compression_state, nncf_state)``. Both are ``None`` unless
        the checkpoint carries ``meta["nncf_meta"]``, in which case
        ``compression_state`` is ``nncf_meta.compression_ctrl`` and
        ``nncf_state`` is ``checkpoint["state_dict"]``.

    Raises:
        RuntimeError: If the loaded object is neither an ``OrderedDict`` nor a
            ``dict`` containing ``"state_dict"``.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects. The NNCF
    # meta is read via attribute access below, so it is presumably a custom
    # class rather than plain tensors. Only load checkpoints from trusted
    # sources. On PyTorch versions where ``weights_only`` defaults to True,
    # such checkpoints may fail to load; confirm against the supported torch.
    checkpoint = torch.load(filename, map_location=map_location)

    compression_state = None
    nncf_state = None
    if isinstance(checkpoint, OrderedDict):
        # A bare state dict: load it directly.
        base_state = checkpoint
    elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        if "meta" in checkpoint and "nncf_meta" in checkpoint["meta"]:
            nncf_meta = checkpoint["meta"]["nncf_meta"]
            # The stored state_dict is returned to the caller; the model itself
            # is populated from nncf_meta.state_to_build.
            nncf_state = checkpoint["state_dict"]
            compression_state = nncf_meta.compression_ctrl
            base_state = nncf_meta.state_to_build
        else:
            base_state = checkpoint["state_dict"]
    else:
        raise RuntimeError(f"No state_dict found in checkpoint file {filename}")

    # Imported only once the checkpoint layout has been validated, so a
    # malformed checkpoint reports the RuntimeError above regardless of mmcv.
    from mmcv.runner import load_state_dict

    load_state_dict(model, base_state, strict=strict)
    return compression_state, nncf_state
def nullcontext():
    """Return a context manager that does nothing.

    This is a thin backward-compatible wrapper around the standard library's
    :class:`contextlib.nullcontext`. ``with nullcontext() as value`` binds
    ``None``, and unlike a generator-based manager the returned object can be
    entered more than once.
    """
    import contextlib

    return contextlib.nullcontext()
[docs]
def no_nncf_trace():
    """Return NNCF's ``no_nncf_trace`` context, or a no-op context without NNCF.

    When NNCF is installed, this delegates to
    ``nncf.torch.dynamic_graph.context.no_nncf_trace``. Otherwise it falls back
    to :func:`nullcontext`, so callers can always write
    ``with no_nncf_trace(): ...``.
    """
    if not is_nncf_enabled():
        return nullcontext()
    from nncf.torch.dynamic_graph.context import (
        no_nncf_trace as _nncf_no_trace,
    )

    return _nncf_no_trace()
def nncf_trace():
    """Return a context that turns NNCF tracing on for the enclosed code.

    If NNCF is installed and a current NNCF context exists that is not already
    tracing, tracing is enabled on entry and disabled again on exit. This holds
    even if the ``with`` body raises. In every other case the returned context
    does nothing.
    """
    if not is_nncf_enabled():
        return nullcontext()

    @contextmanager
    def _nncf_trace():
        from nncf.torch.dynamic_graph.context import get_current_context

        ctx = get_current_context()
        if ctx is None or ctx.is_tracing:
            # No NNCF context, or tracing already on: nothing to toggle.
            yield
            return
        ctx.enable_tracing()
        try:
            yield
        finally:
            # Restore the previous (non-tracing) state even when the body
            # raises; otherwise tracing would leak into later code.
            ctx.disable_tracing()

    return _nncf_trace()
[docs]
def is_in_nncf_tracing():
    """Tell whether the current NNCF context is tracing.

    Returns:
        bool-like: ``False`` when NNCF is not installed or there is no current
        NNCF context; otherwise the context's ``is_tracing`` value.
    """
    if is_nncf_enabled():
        from nncf.torch.dynamic_graph.context import get_current_context

        current = get_current_context()
        return False if current is None else current.is_tracing
    return False
[docs]
def is_accuracy_aware_training_set(nncf_config):
    """Tell whether ``nncf_config`` requests accuracy-aware training.

    Args:
        nncf_config: NNCF configuration, forwarded unchanged to
            ``nncf.config.utils.is_accuracy_aware_training``.

    Returns:
        ``False`` when NNCF is not installed; otherwise whatever NNCF's
        ``is_accuracy_aware_training`` returns for ``nncf_config``.
    """
    if is_nncf_enabled():
        from nncf.config.utils import is_accuracy_aware_training

        return is_accuracy_aware_training(nncf_config)
    return False