# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""This implementation replaces the functionality of mmengine utils."""
# TODO(someone): Revisit mypy errors after deprecation of mmlab
# mypy: ignore-errors
from __future__ import annotations
import copy
import os
import re
from collections import OrderedDict, abc, namedtuple
from pathlib import Path
from typing import Any, Iterator, Union
from warnings import warn
import numpy as np
import torch
from torch import distributed as torch_dist
from torch import nn
from torch.utils.model_zoo import load_url
[docs]
def get_dist_info() -> tuple[int, int]:
    """Return the ``(rank, world_size)`` of the default process group.

    Note:
        Calling ``get_dist_info`` in a non-distributed environment returns
        ``(0, 1)``.

    Returns:
        tuple[int, int]: A tuple containing the ``rank`` and ``world_size``.
    """
    if not (torch_dist.is_available() and torch_dist.is_initialized()):
        return 0, 1
    # Old torch releases (e.g. 1.5.0) do not accept ``group=None``, so the
    # default group is fetched explicitly and passed in.
    group = torch_dist.distributed_c10d._get_default_group()  # noqa: SLF001
    return torch_dist.get_rank(group), torch_dist.get_world_size(group)
[docs]
def load_checkpoint(
    model: nn.Module,
    checkpoint: str,
    map_location: str = "cpu",
    strict: bool = False,
    prefix: str = "",
) -> None:
    """Load a state dict from a checkpoint path (or URL) and copy it into ``model``.

    Args:
        model (nn.Module): Model that receives the weights.
        checkpoint (str): Local file path, or an HTTP URL when no such file exists.
        map_location (str): Device mapping passed to the loader. Defaults to "cpu".
        strict (bool): Whether key mismatches raise an error. Defaults to False.
        prefix (str): Key prefix stripped before loading. Defaults to "".
    """
    if Path(checkpoint).exists():
        # Local file: read it straight from disk.
        state = torch.load(checkpoint, map_location)
    else:
        # No such file: treat ``checkpoint`` as an HTTP location.
        state = load_from_http(checkpoint, map_location)
    load_checkpoint_to_model(model, state, strict=strict, prefix=prefix)
[docs]
def load_from_http(
    filename: str,
    map_location: str | None = None,
    model_dir: Path | str | None = None,
    progress: bool = os.isatty(0),
) -> dict[str, Any]:
    """Load a checkpoint from an HTTP URL.

    Copy of mmengine.runner.checkpoint.load_from_http.

    Args:
        filename (str): The URL of the checkpoint file.
        map_location (str | None, optional): Specifies where to load the checkpoint onto.
            Defaults to None.
        model_dir (Path | str | None, optional): The directory to save the downloaded checkpoint.
            Defaults to None.
        progress (bool, optional): Whether to display a progress bar while downloading.
            Defaults to True if running in a terminal, otherwise False.

    Returns:
        dict[str, Any]: The loaded checkpoint.
    """

    def _fetch() -> dict[str, Any]:
        # Same call for every rank; non-zero ranks hit the cache rank 0 filled.
        return load_url(filename, model_dir=model_dir, map_location=map_location, progress=progress)

    rank, world_size = get_dist_info()
    if rank == 0:
        # Rank 0 downloads first so the file lands in the shared cache.
        checkpoint = _fetch()
    if world_size > 1:
        # All ranks wait until the download has finished.
        torch_dist.barrier()
    if rank > 0:
        checkpoint = _fetch()
    return checkpoint
class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()
__str__ = __repr__
[docs]
def load_state_dict(module: nn.Module, state_dict: OrderedDict, strict: bool = False) -> None:
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Defaults to False.

    Raises:
        RuntimeError: If ``strict`` is True and any keys are missing or
            unexpected (raised on rank 0 only).
    """
    # Accumulators filled in-place by ``_load_from_state_dict`` during the
    # recursive walk below.
    unexpected_keys: list[str] = []
    missing_keys: list[str] = []
    err_msg: list[str] = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata  # noqa: SLF001

    # use _load_from_state_dict to enable checkpoint version control
    def load(module: nn.Module, local_state_dict: dict, prefix: str = "") -> None:
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(  # noqa: SLF001
            local_state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            err_msg,
        )
        for name, child in module._modules.items():  # noqa: SLF001
            if child is not None:
                child_prefix = prefix + name + "."
                # Hand each child only the keys under its own prefix.
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)
        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        if hasattr(module, "_load_state_dict_post_hooks"):
            for hook in module._load_state_dict_post_hooks.values():  # noqa: SLF001
                _ = hook(module, incompatible_keys)

    load(module, state_dict)
    load = None  # break load->load reference cycle
    # ignore "num_batches_tracked" of BN layers
    missing_keys = [key for key in missing_keys if "num_batches_tracked" not in key]
    if unexpected_keys:
        err_msg.append(f'unexpected key in source state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
    rank, _ = get_dist_info()
    # Only rank 0 reports mismatches; other ranks stay silent even when
    # ``strict`` is set.
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
        if strict:
            raise RuntimeError("\n".join(err_msg))
        warn("\n".join(err_msg), stacklevel=1)
[docs]
def load_checkpoint_to_model(
    model: nn.Module,
    checkpoint: dict,
    strict: bool = False,
    prefix: str = "",
) -> None:
    """Loads a checkpoint dictionary into a PyTorch model.

    Copy of mmengine.runner.checkpoint._load_checkpoint_to_model.

    Args:
        model (nn.Module): The PyTorch model to load the checkpoint into.
        checkpoint (dict): The checkpoint dictionary containing the model's state_dict.
        strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys
            in the model's state_dict. Defaults to False.
        prefix (str, optional): A key prefix (e.g. ``"backbone"``) to strip from
            every state_dict key before loading. Defaults to "".

    Returns:
        None
    """
    # get state_dict from checkpoint
    # TODO(Kirill): remove this when RTDETR weights is updloaded to openvino storage.
    state_dict = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint.get("state_dict", checkpoint)
    # TODO(Eugene): remove this when MaskDINO weights is updloaded to openvino storage.
    state_dict = state_dict.get("model", state_dict)
    # strip prefix of state_dict
    metadata = getattr(state_dict, "_metadata", OrderedDict())
    # Only strip a caller-supplied prefix when one is actually given: the
    # previous unconditional ``rf"^{prefix}\."`` degenerated to ``^\.`` for an
    # empty prefix (stripping leading dots from keys) and misbehaved when the
    # prefix contained regex metacharacters. ``re.escape`` fixes the latter.
    patterns = [r"^module\."]
    if prefix:
        patterns.append(rf"^{re.escape(prefix)}\.")
    for p in patterns:
        state_dict = OrderedDict({re.sub(p, "", k): v for k, v in state_dict.items()})
    # Keep metadata in state_dict
    state_dict._metadata = metadata  # noqa: SLF001
    # load state_dict
    load_state_dict(model, state_dict, strict)
[docs]
def is_seq_of(
    seq: Any,  # noqa: ANN401
    expected_type: type | tuple,
    seq_type: type | None = None,
) -> bool:
    """Check whether ``seq`` is a sequence whose items all match ``expected_type``.

    Copied from mmengine.utils.misc.is_seq_of

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type or tuple): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type. Defaults to None.

    Returns:
        bool: Return True if ``seq`` is valid else False.

    Examples:
        >>> seq = ['a', 'b', 'c']
        >>> is_seq_of(seq, str)
        True
        >>> is_seq_of(seq, int)
        False
    """
    # Any ``abc.Sequence`` is acceptable unless the caller pins a concrete type.
    container_type = abc.Sequence if seq_type is None else seq_type
    if not isinstance(seq, container_type):
        return False
    for item in seq:
        if not isinstance(item, expected_type):
            return False
    return True
[docs]
def is_tuple_of(seq: Any, expected_type: type | tuple) -> bool:  # noqa: ANN401
    """Check whether ``seq`` is a tuple whose items are all of ``expected_type``.

    Copied from mmengine.utils.misc.is_tuple_of; this is :func:`is_seq_of`
    with the sequence type fixed to :class:`tuple`.
    """
    return is_seq_of(seq, expected_type, seq_type=tuple)
# Aliases accepting both the CPU and CUDA variants of the typed tensors.
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
# Every index form accepted by ``InstanceData.__getitem__``.
IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]
[docs]
class InstanceData:
"""A base data interface that supports Tensor-like and dict-like operations.
This class is from https://github.com/open-mmlab/mmengine/blob/66fb81f7b392b2cd304fc1979d8af3cc71a011f5/mmengine/structures/instance_data.py
and slightly modified.
Args:
metainfo (dict, optional): A dict contains the meta information
of single image, such as ``dict(img_shape=(512, 512, 3),
scale_factor=(1, 1, 1, 1))``. Defaults to None.
kwargs (dict, optional): A dict contains annotations of single image or
model predictions. Defaults to None.
"""
def __init__(self, *, metainfo: dict | None = None, **kwargs) -> None:
self._metainfo_fields: set = set()
self._data_fields: set = set()
if metainfo is not None:
self.set_metainfo(metainfo=metainfo)
if kwargs:
self.set_data(kwargs)
[docs]
def set_data(self, data: dict) -> None:
"""Set or change key-value pairs in ``data_field`` by parameter ``data``.
Args:
data (dict): A dict contains annotations of image or
model predictions.
"""
for k, v in data.items():
# Use `setattr()` rather than `self.set_field` to allow `set_data`
# to set property method.
setattr(self, k, v)
[docs]
def update(self, instance: InstanceData) -> None:
"""The method updates the InstanceData with the elements from another InstanceData object.
Args:
instance (InstanceData): Another InstanceData object for
update the current object.
"""
self.set_metainfo(dict(instance.metainfo_items()))
self.set_data(dict(instance.items()))
[docs]
def new(self, *, metainfo: dict | None = None, **kwargs) -> InstanceData:
"""Return a new data element with same type.
If ``metainfo`` and ``data`` are None, the new data element will have same metainfo and
data. If metainfo or data is not None, the new result will overwrite it
with the input value.
Args:
metainfo (dict, optional): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc.
Defaults to None.
kwargs (dict): A dict contains annotations of image or
model predictions.
Returns:
InstanceData: A new data element with same type.
"""
new_data = self.__class__()
if metainfo is not None:
new_data.set_metainfo(metainfo)
else:
new_data.set_metainfo(dict(self.metainfo_items()))
if kwargs:
new_data.set_data(kwargs)
else:
new_data.set_data(dict(self.items()))
return new_data
[docs]
def clone(self) -> InstanceData:
"""Deep copy the current data element.
Returns:
InstanceData: The copy of current data element.
"""
clone_data = self.__class__()
clone_data.set_metainfo(dict(self.metainfo_items()))
clone_data.set_data(dict(self.items()))
return clone_data
[docs]
def keys(self) -> list:
"""Returns lits contains all keys in data_fields."""
private_keys = {"_" + key for key in self._data_fields if isinstance(getattr(type(self), key, None), property)}
return list(self._data_fields - private_keys)
[docs]
def values(self) -> list:
"""Returns list contains all values in data."""
return [getattr(self, k) for k in self.keys()]
[docs]
def all_keys(self) -> list:
"""Returns list contains all keys in metainfo and data."""
return self.metainfo_keys() + self.keys()
[docs]
def all_values(self) -> list:
"""Returns list contains all values in metainfo and data."""
return self.metainfo_values() + self.values()
[docs]
def all_items(self) -> Iterator[tuple[str, Any]]:
"""Returns iterator object whose element is (key, value) tuple pairs for ``metainfo`` and ``data``."""
for k in self.all_keys():
yield (k, getattr(self, k))
[docs]
def items(self) -> Iterator[tuple[str, Any]]:
"""Returns iterator object whose element is (key, value) tuple pairs for ``data``."""
for k in self.keys():
yield (k, getattr(self, k))
@property
def metainfo(self) -> dict:
"""dict: A dict contains metainfo of current data element."""
return dict(self.metainfo_items())
def __setattr__(self, name: str, value: Any): # noqa: ANN401
"""Setattr is only used to set data."""
if name in ("_metainfo_fields", "_data_fields"):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
msg = f"{name} has been used as a private attribute, which is immutable."
raise AttributeError(msg)
else:
self.set_field(name=name, value=value, field_type="data", dtype=None)
__setitem__ = __setattr__
    def __getitem__(self, item: IndexType) -> InstanceData:
        """Get item method.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item.

        Returns:
            :obj:`InstanceData`: Corresponding values.
        """
        # Normalize list indices to arrays so one tensor code path handles them.
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)
        if isinstance(item, str):
            # A string index is plain attribute access, not slicing.
            return getattr(self, item)
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):
                msg = f"Index {item} out of range!"
                raise IndexError(msg)
            # Represent the single index as a slice; the step of ``len(self)``
            # means exactly one element is selected while container fields
            # keep their sequence type.
            item = slice(item, None, len(self))
        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(v, (str, list, tuple)) or (hasattr(v, "__getitem__") and hasattr(v, "cat")):
                    # convert to indexes from BoolTensor
                    if isinstance(item, BoolTypeTensor.__args__):
                        indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            # One-element slices (step = len(v)) preserve the
                            # container type, e.g. slicing a str yields a str.
                            slice_list.append(slice(index, None, len(v)))  # noqa: PERF401
                        # Empty selection: a slice that produces an empty result.
                    else:
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        # Built-in sequences concatenate with ``+``.
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        # Other containers must expose a ``cat`` attribute.
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    msg = (
                        f"The type of `{k}` is `{type(v)}`, "
                        "which has no attribute of `cat`, so it does not support slice with `bool`"
                    )
                    raise ValueError(msg)
        else:
            # item is a slice
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data
def __delattr__(self, item: str):
"""Delete the item in dataelement.
Args:
item (str): The key to delete.
"""
if item in ("_metainfo_fields", "_data_fields"):
msg = f"{item} has been used as a private attribute, which is immutable."
raise AttributeError(msg)
super().__delattr__(item)
if item in self._metainfo_fields:
self._metainfo_fields.remove(item)
elif item in self._data_fields:
self._data_fields.remove(item)
# dict-like methods
__delitem__ = __delattr__
[docs]
def get(self, key: str, default: Any | None = None) -> Any: # noqa: ANN401
"""Get property in data and metainfo as the same as python."""
# Use `getattr()` rather than `self.__dict__.get()` to allow getting
# properties.
return getattr(self, key, default)
[docs]
def pop(self, *args) -> Any: # noqa: ANN401
"""Pop property in data and metainfo as the same as python."""
name = args[0]
if name in self._metainfo_fields:
self._metainfo_fields.remove(args[0])
return self.__dict__.pop(*args)
if name in self._data_fields:
self._data_fields.remove(args[0])
return self.__dict__.pop(*args)
# with default value
if len(args) == 2:
return args[1]
msg = f"{args[0]} is not contained in metainfo or data"
raise KeyError(msg)
def __contains__(self, item: str) -> bool:
"""Whether the item is in dataelement.
Args:
item (str): The key to inquire.
"""
return item in self._data_fields or item in self._metainfo_fields
[docs]
def set_field(
self,
value: Any, # noqa: ANN401
name: str,
dtype: type | tuple[type, ...] | None = None,
field_type: str = "data",
) -> None:
"""Special method for set union field, used as property.setter functions."""
if field_type == "metainfo":
if name in self._data_fields:
msg = f"Cannot set {name} to be a field of metainfo because {name} is already a data field"
raise AttributeError(msg)
self._metainfo_fields.add(name)
else:
if name in self._metainfo_fields:
msg = f"Cannot set {name} to be a field of data because {name} is already a metainfo field"
raise AttributeError(msg)
self._data_fields.add(name)
super().__setattr__(name, value)
# Tensor-like methods
[docs]
def to(self, *args, **kwargs) -> InstanceData:
"""Apply same name function to all tensors in data_fields."""
new_data = self.new()
for k, v in self.items():
if hasattr(v, "to"):
v = v.to(*args, **kwargs) # noqa: PLW2901
data = {k: v}
new_data.set_data(data)
return new_data
# Tensor-like methods
[docs]
def cpu(self) -> InstanceData:
"""Convert all tensors to CPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, InstanceData)):
v = v.cpu() # noqa: PLW2901
data = {k: v}
new_data.set_data(data)
return new_data
# Tensor-like methods
[docs]
def cuda(self) -> InstanceData:
"""Convert all tensors to GPU in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, InstanceData)):
v = v.cuda() # noqa: PLW2901
data = {k: v}
new_data.set_data(data)
return new_data
# Tensor-like methods
[docs]
def detach(self) -> InstanceData:
"""Detach all tensors in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, InstanceData)):
v = v.detach() # noqa: PLW2901
data = {k: v}
new_data.set_data(data)
return new_data
# Tensor-like methods
[docs]
def numpy(self) -> InstanceData:
"""Convert all tensors to np.ndarray in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, InstanceData)):
v = v.detach().cpu().numpy() # noqa: PLW2901
data = {k: v}
new_data.set_data(data)
return new_data
[docs]
def to_tensor(self) -> InstanceData:
"""Convert all np.ndarray to tensor in data."""
new_data = self.new()
for k, v in self.items():
data = {}
if isinstance(v, np.ndarray):
v = torch.from_numpy(v) # noqa: PLW2901
data[k] = v
elif isinstance(v, InstanceData):
v = v.to_tensor() # noqa: PLW2901
data[k] = v
new_data.set_data(data)
return new_data
[docs]
def to_dict(self) -> dict:
"""Convert InstanceData to dict."""
return {k: v.to_dict() if isinstance(v, InstanceData) else v for k, v in self.all_items()}
    def __repr__(self) -> str:
        """Represent the object."""

        def _addindent(s_: str, num_spaces: int) -> str:
            """This func is modified from `pytorch`.

            https://github.com/pytorch/
            pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
            les/module.py#L29.

            Args:
                s_ (str): The string to add spaces.
                num_spaces (int): The num of space to add.

            Returns:
                str: The string after add indent.
            """
            s = s_.split("\n")
            # don't do anything for single-line stuff
            if len(s) == 1:
                return s_
            # Indent every line except the first, which stays where it is.
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            return first + "\n" + s

        def dump(obj: Any) -> str:  # noqa: ANN401
            """Represent the object.

            Args:
                obj (Any): The obj to represent.

            Returns:
                str: The represented str.
            """
            _repr = ""
            if isinstance(obj, dict):
                # One "key: value" line per entry, values indented recursively.
                for k, v in obj.items():
                    _repr += f"\n{k}: {_addindent(dump(v), 4)}"
            elif isinstance(obj, InstanceData):
                # Nested elements render their metainfo and data sections.
                _repr += "\n\n META INFORMATION"
                metainfo_items = dict(obj.metainfo_items())
                _repr += _addindent(dump(metainfo_items), 4)
                _repr += "\n\n DATA FIELDS"
                items = dict(obj.items())
                _repr += _addindent(dump(items), 4)
                classname = obj.__class__.__name__
                _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>"
            else:
                # Leaf values fall back to their own repr.
                _repr += repr(obj)
            return _repr

        return dump(self)
def __len__(self) -> int:
"""int: The length of InstanceData."""
if len(self._data_fields) > 0:
return len(self.values()[0])
return 0