Source code for otx.algorithms.common.adapters.nncf.compression
"""NNCF utils."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import numpy as np
import torch
from .utils import check_nncf_is_enabled, get_nncf_version
@dataclass
class NNCFMetaState:
    """NNCF meta state wrapper."""

    state_to_build: Optional[Dict[str, torch.Tensor]] = field(default=None)
    data_to_build: Optional[np.ndarray] = field(default=None)
    compression_ctrl: Optional[Dict[Any, Any]] = field(default=None)

    def __repr__(self):
        """Repr."""
        out = f"{self.__class__.__name__}("
        if self.state_to_build is not None:
            out += "state_to_build='<data>', "
        if self.data_to_build is not None:
            out += "data_to_build='<data>', "
        if self.compression_ctrl is not None:
            out += "compression_ctrl='<data>', "
        if out[-2:] == ", ":
            out = out[:-2]
        out += ")"
        return out
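

# Hedged usage sketch (an illustrative helper, not part of the original module):
# it shows how the __repr__ above hides stored tensor/array payloads behind
# '<data>' placeholders so large data never ends up in log output.
def _example_nncf_meta_state_repr():
    meta = NNCFMetaState(data_to_build=np.zeros((2, 2), dtype=np.float32))
    return repr(meta)  # "NNCFMetaState(data_to_build='<data>')"
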
def is_state_nncf(state):
    """Check if a state_dict is an NNCF state_dict.

    The function uses metadata stored in the state_dict to check whether the
    checkpoint was the result of training an NNCF-compressed model.
    See the function get_nncf_metadata above.
    """
    return bool(state.get("meta", {}).get("nncf_enable_compression", False))
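

# Hedged usage sketch (an illustrative helper, not part of the original module):
# it spells out the metadata layout that is_state_nncf() looks for inside a
# loaded checkpoint dict.
def _example_is_state_nncf():
    compressed = {"meta": {"nncf_enable_compression": True}, "state_dict": {}}
    plain = {"state_dict": {}}
    return is_state_nncf(compressed), is_state_nncf(plain)  # (True, False)
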
def is_checkpoint_nncf(path):
    """Check if the file at `path` is an NNCF checkpoint.

    The function uses metadata stored in a checkpoint to check whether the
    checkpoint was the result of training an NNCF-compressed model.
    See the function get_nncf_metadata above.
    """
    try:
        checkpoint = torch.load(path, map_location="cpu")
        return is_state_nncf(checkpoint)
    except FileNotFoundError:
        return False
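

# Hedged usage sketch (an illustrative helper, not part of the original module;
# the file name is arbitrary): it saves a checkpoint carrying the NNCF metadata
# and probes it with is_checkpoint_nncf().
def _example_is_checkpoint_nncf(path="nncf_example_ckpt.pth"):
    torch.save({"meta": {"nncf_enable_compression": True}}, path)
    return is_checkpoint_nncf(path)  # True; a missing file would return False
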
class AccuracyAwareLrUpdater:
    """AccuracyAwareLrUpdater."""

    def __init__(self, lr_hook):
        self._lr_hook = lr_hook
        self._lr_hook.warmup_iters = 0  # disable the wrapped hook's warmup phase
    def step(self, *args, **kwargs):
        """step."""

    @property
    def base_lrs(self):
        """base_lrs."""
        return self._lr_hook.base_lr

    @base_lrs.setter
    def base_lrs(self, value):
        self._lr_hook.base_lr = value
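

# Hedged usage sketch (an illustrative helper, not part of the original module):
# the wrapper only needs an lr_hook object exposing `warmup_iters` and `base_lr`,
# so a stand-in object is used here in place of a real mmcv LrUpdaterHook.
def _example_accuracy_aware_lr_updater():
    class _StubLrHook:
        warmup_iters = 5
        base_lr = 0.01

    updater = AccuracyAwareLrUpdater(_StubLrHook())
    updater.base_lrs = 0.001  # forwarded to the wrapped hook's base_lr
    return updater.base_lrs   # 0.001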