"""Simple monkey patch helper."""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# pylint: disable=unnecessary-dunder-call,invalid-name
import ctypes
import importlib
import inspect
from collections import OrderedDict
from functools import partial, partialmethod
from typing import Callable
class Patcher:
    """Simple monkey patch helper.

    A ``Patcher`` replaces an attribute of a module, a class or a single
    instance with a shim that calls ``wrapper(ctx, fn, *args, **kwargs)``.
    ``ctx`` is the patched module, the patched class (or the instance the call
    was made on), or the patched instance; ``fn`` is the callable that was
    replaced. Every replacement is recorded in ``self._patched`` so that
    :meth:`unpatch` can undo it.
    """

    # Names under which ``functools.partialmethod`` tags the plain function it
    # hands out on class-level access: ``_partialmethod`` on older CPython,
    # ``__partialmethod__`` on newer releases. Both are accepted so that
    # already-patched class attributes can be unwrapped on either.
    _PARTIALMETHOD_MARKERS = ("_partialmethod", "__partialmethod__")

    def __init__(self):
        # Maps (owner_id, attribute_name) -> list of (replaced_fn, wrapper)
        # tuples in the order the patches were applied. ``owner_id`` is the
        # module name, "<module>.<class name>" or ``id(instance)``.
        self._patched = OrderedDict()

    def patch(  # noqa: C901
        self,
        obj_cls,
        wrapper: Callable,
        *,
        force: bool = True,
    ):
        """Do monkey patch.

        Args:
            obj_cls: What to patch. Either an ``(owner, attribute_name)`` pair
                (tuple or list), a dotted path string such as
                ``"pkg.module.Class.method"``, or the callable itself (it is
                resolved back to its owner by :meth:`import_obj`).
            wrapper: Callable invoked as ``wrapper(ctx, fn, *args, **kwargs)``
                in place of the patched attribute.
            force: When true (default), wrappers previously recorded for the
                same target are dropped and the original callable is wrapped.
                When false, the new wrapper is stacked on top of whatever is
                currently installed.

        Nothing is patched (and nothing is raised) when the raw attribute
        lookup on the resolved owner fails with ``AttributeError``.
        """
        if isinstance(obj_cls, (tuple, list)):
            assert len(obj_cls) == 2
            obj_cls, fn_name = obj_cls
            assert getattr(obj_cls, fn_name)
        else:
            obj_cls, fn_name = self.import_obj(obj_cls)

        # Pick the raw-lookup arguments and the installer for the owner kind.
        if self._getattribute_arity(obj_cls) == 1:
            lookup_args, install = (fn_name,), self._patch_module_fn
        elif inspect.isclass(obj_cls):
            # Called through the class, ``__getattribute__`` is unbound, so the
            # class itself has to be passed as the object to look into.
            lookup_args, install = (obj_cls, fn_name), self._patch_class_fn
        else:
            lookup_args, install = (fn_name,), self._patch_instance_fn

        # wrap only if function does exist
        try:
            fn = obj_cls.__getattribute__(*lookup_args)
        except AttributeError:
            return
        install(obj_cls, fn_name, fn, wrapper, force)

    def unpatch(self, obj_cls=None, depth=0):
        """Undo monkey patch.

        Args:
            obj_cls: Target to restore, in any form accepted by :meth:`patch`
                (``(owner, attribute_name)`` pair, dotted path string or the
                callable itself). ``None`` restores every recorded target.
            depth: Number of stacked patches to undo per target, counted from
                the most recent one. ``0`` undoes all of them.

        Raises:
            KeyError: If ``obj_cls`` is given but has no recorded patch.
            IndexError: If ``depth`` exceeds the number of recorded patches.
        """
        if obj_cls is None:
            # Snapshot the keys: ``_restore`` removes entries while we iterate.
            for key in list(self._patched):
                owner, fn_name = key
                if isinstance(owner, int):
                    # Instance patches are keyed by ``id()``; turn the address
                    # back into the object it refers to.
                    owner = ctypes.cast(owner, ctypes.py_object).value
                else:
                    owner, fn_name = self.import_obj(".".join([owner, fn_name]))
                self._restore(owner, fn_name, key, depth)
            return

        if isinstance(obj_cls, (tuple, list)):
            # Same explicit form that ``patch`` accepts.
            assert len(obj_cls) == 2
            obj_cls, fn_name = obj_cls
        else:
            obj_cls, fn_name = self.import_obj(obj_cls)
        self._restore(obj_cls, fn_name, self._make_key(obj_cls, fn_name), depth)

    def import_obj(self, obj_cls):  # noqa: C901
        """Object import helper.

        Resolve ``obj_cls`` to an ``(owner, attribute_name)`` pair.

        Args:
            obj_cls: A dotted path string, or a callable: plain function, bound
                method, ``staticmethod``/``classmethod`` object, or an
                attribute previously installed by this patcher.

        Returns:
            Tuple ``(owner, attribute_name)`` where ``owner`` is the imported
            module, the class found in its module, or (for bound methods) the
            object the method is bound to.
        """
        if isinstance(obj_cls, str):
            # "a.b.c" -> owner path "a.b", attribute "c"
            obj_cls, _, fn_name = obj_cls.rpartition(".")
        else:
            # Peel off shims installed by earlier patches until the callable
            # that was originally replaced is reached.
            marker = self._partialmethod_marker(obj_cls)
            while marker is not None:
                obj_cls = marker.keywords["__fn"]
                marker = self._partialmethod_marker(obj_cls)
            while isinstance(obj_cls, (partial, partialmethod)):
                obj_cls = obj_cls.keywords["__fn"]

            if inspect.ismodule(obj_cls):
                # NOTE(review): ordinary module objects have no ``keywords``
                # attribute, so this branch raises AttributeError for them.
                # Kept as-is; confirm the intended input with the callers.
                fn_name = obj_cls.keywords["__fn"].__name__
                obj_cls = obj_cls.keywords["__obj_cls"]
            elif inspect.ismethod(obj_cls):
                fn_name = obj_cls.__name__
                obj_cls = obj_cls.__self__
            else:
                if isinstance(obj_cls, (staticmethod, classmethod)):
                    obj_cls = obj_cls.__func__
                fn_name = obj_cls.__name__
                # Owner path = defining module + qualified name minus the
                # attribute's own name.
                obj_cls = ".".join([obj_cls.__module__, *obj_cls.__qualname__.split(".")[:-1]])

        if isinstance(obj_cls, str):
            try:
                obj_cls = importlib.import_module(obj_cls)
            except ModuleNotFoundError:
                # Not a module: treat the last component as an attribute
                # (typically a class) of the parent module.
                module, _, attr = obj_cls.rpartition(".")
                obj_cls = getattr(importlib.import_module(module), attr)
        return obj_cls, fn_name

    @staticmethod
    def _getattribute_arity(obj):
        """Return the argument count ``inspect`` reports for ``obj.__getattribute__``.

        ``patch`` and ``unpatch`` treat a count of 1 as "``obj`` is a module";
        the class and instance installers assert a count of 2.
        """
        return len(inspect.getfullargspec(obj.__getattribute__)[0])

    @classmethod
    def _partialmethod_marker(cls, obj):
        """Return the ``partialmethod`` tagged on ``obj``, or ``None`` if untagged."""
        namespace = getattr(obj, "__dict__", None) or {}
        for name in cls._PARTIALMETHOD_MARKERS:
            if name in namespace:
                return namespace[name]
        return None

    def _make_key(self, obj, fn_name):
        """Build the ``_patched`` key that the installers use for ``obj``."""
        if self._getattribute_arity(obj) == 1:
            return (obj.__name__, fn_name)
        if inspect.isclass(obj):
            return (obj.__module__ + "." + obj.__name__, fn_name)
        return (id(obj), fn_name)

    def _restore(self, obj, fn_name, key, depth):
        """Pop ``depth`` recorded patches for ``key`` and reinstall the older callable.

        ``obj`` must already be the owner object itself. It is deliberately
        not re-interpreted as an address here: doing so would dereference a
        bogus pointer if the patched object happens to be an ``int`` subclass
        instance.
        """
        stack = self._patched[key]
        if depth == 0:
            depth = len(stack)
        keep = len(stack) - depth
        # The entry ``depth`` patches back holds what was installed before it.
        origin_fn = stack.pop(-depth)[0]
        del stack[keep:]
        if not stack:
            self._patched.pop(key)
        setattr(obj, fn_name, origin_fn)

    def _patch_module_fn(self, obj_cls, fn_name, fn, wrapper, force):
        """Replace module attribute ``fn_name`` with a ``partial`` shim."""

        def helper(*args, **kwargs):  # type: ignore
            obj_cls = kwargs.pop("__obj_cls")
            fn = kwargs.pop("__fn")
            wrapper = kwargs.pop("__wrapper")
            return wrapper(obj_cls, fn, *args, **kwargs)

        assert self._getattribute_arity(obj_cls) == 1
        key = (obj_cls.__name__, fn_name)
        original = self._initialize(key, force)
        if original is not None:
            fn = original
        # The bookkeeping values travel as keyword arguments of the partial so
        # that ``import_obj`` can read them back from ``.keywords``.
        shim = partial(helper, **{"__wrapper": wrapper, "__fn": fn, "__obj_cls": obj_cls})
        setattr(obj_cls, fn_name, shim)
        self._patched[key].append((fn, wrapper))

    def _patch_class_fn(self, obj_cls, fn_name, fn, wrapper, force):
        """Replace class attribute ``fn_name`` with a ``partialmethod`` shim."""
        assert self._getattribute_arity(obj_cls) == 2
        key = (obj_cls.__module__ + "." + obj_cls.__name__, fn_name)
        # Resolve the callable that will actually be wrapped *before* choosing
        # the shim: on a forced re-patch the attribute currently installed is
        # our own partialmethod, which says nothing about whether the original
        # was a staticmethod/classmethod.
        original = self._initialize(key, force)
        if original is not None:
            fn = original

        if isinstance(fn, (staticmethod, classmethod)):

            def helper(*args, **kwargs):  # type: ignore
                wrapper = kwargs.pop("__wrapper")
                fn = kwargs.pop("__fn")
                obj_cls = kwargs.pop("__obj_cls")
                if args and isinstance(args[0], obj_cls):
                    # Called through an instance: it arrives as first argument.
                    return wrapper(args[0], fn.__get__(args[0]), *args[1:], **kwargs)
                # Called through the class: bind with the class as *owner* so
                # that a classmethod receives ``obj_cls`` and not its metaclass.
                return wrapper(obj_cls, fn.__get__(None, obj_cls), *args, **kwargs)

        elif isinstance(fn, type(all.__call__)):

            def helper(self, *args, **kwargs):  # type: ignore
                kwargs.pop("__obj_cls")
                wrapper = kwargs.pop("__wrapper")
                fn = kwargs.pop("__fn")
                # Method-wrapper objects are passed on without re-binding.
                return wrapper(self, fn, *args, **kwargs)

        else:

            def helper(self, *args, **kwargs):  # type: ignore
                kwargs.pop("__obj_cls")
                wrapper = kwargs.pop("__wrapper")
                fn = kwargs.pop("__fn")
                return wrapper(self, fn.__get__(self), *args, **kwargs)

        shim = partialmethod(helper, **{"__wrapper": wrapper, "__fn": fn, "__obj_cls": obj_cls})
        setattr(obj_cls, fn_name, shim)
        self._patched[key].append((fn, wrapper))

    def _patch_instance_fn(self, obj_cls, fn_name, fn, wrapper, force):
        """Replace attribute ``fn_name`` on the single instance ``obj_cls``."""

        def helper(ctx, *args, **kwargs):  # type: ignore
            fn = kwargs.pop("__fn")
            wrapper = kwargs.pop("__wrapper")
            return wrapper(ctx, fn, *args, **kwargs)

        assert self._getattribute_arity(obj_cls) == 2
        # Instances are keyed by identity, not by an importable path.
        key = (id(obj_cls), fn_name)
        original = self._initialize(key, force)
        if original is not None:
            fn = original
        shim = partialmethod(helper, **{"__wrapper": wrapper, "__fn": fn})
        # Bind right away: the shim is stored on the instance, where the
        # descriptor protocol would not be applied on attribute access.
        setattr(obj_cls, fn_name, shim.__get__(obj_cls))
        self._patched[key].append((fn, wrapper))

    def _initialize(self, key, force):
        """Ensure ``key`` has a patch stack; with ``force`` empty it.

        Returns:
            The callable recorded by the oldest entry when ``force`` is true
            and the stack was not empty, otherwise ``None``.
        """
        stack = self._patched.setdefault(key, [])
        if not (force and stack):
            return None
        original = stack[0][0]
        stack.clear()
        return original