Source code for datumaro.util.scope

# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import threading
from contextlib import ExitStack, contextmanager
from functools import partial, wraps
from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, TypeVar

from attrs import frozen

from datumaro.util import optional_arg_decorator

T = TypeVar("T")


class Scope:
    """
    A context manager that allows to register error and exit callbacks.

    Callbacks and entered context managers are kept on an internal
    ExitStack and are unwound in reverse registration order when the
    scope exits (unless the scope has been disabled).
    """

    # Per-thread storage for the scope installed by as_current().
    _thread_locals = threading.local()

    @frozen
    class _ExitHandler:
        """Exit-stack entry that runs its callback on every scope exit."""

        callback: Callable[[], Any]
        ignore_errors: bool = True

        def __exit__(self, exc_type, exc_value, exc_traceback):
            try:
                self.callback()
            except Exception:
                if not self.ignore_errors:
                    raise
            # Implicitly returns None, so an in-flight exception is never suppressed.

    @frozen
    class _ErrorHandler(_ExitHandler):
        """Exit-stack entry that runs its callback only when exiting with an exception."""

        def __exit__(self, exc_type, exc_value, exc_traceback):
            if exc_type:
                return super().__exit__(
                    exc_type=exc_type, exc_value=exc_value, exc_traceback=exc_traceback
                )

    def __init__(self):
        self._stack = ExitStack()
        self.enabled = True

    def on_error_do(
        self,
        callback: Callable,
        *args,
        kwargs: Optional[Dict[str, Any]] = None,
        ignore_errors: bool = False,
    ):
        """
        Registers a function to be called on scope exit because of an error.

        If ignore_errors is True, the errors from this function call will be ignored.

        Args:
            callback: the function to call
            *args: positional arguments bound to the callback
            kwargs: keyword arguments bound to the callback
            ignore_errors: swallow exceptions raised by the callback
        """

        self._register_callback(
            self._ErrorHandler,
            ignore_errors=ignore_errors,
            callback=callback,
            args=args,
            kwargs=kwargs,
        )

    def on_exit_do(
        self,
        callback: Callable,
        *args,
        kwargs: Optional[Dict[str, Any]] = None,
        ignore_errors: bool = False,
    ):
        """
        Registers a function to be called on scope exit.

        The callback runs whether or not the scope exits because of an error.
        If ignore_errors is True, the errors from this function call will be ignored.

        Args:
            callback: the function to call
            *args: positional arguments bound to the callback
            kwargs: keyword arguments bound to the callback
            ignore_errors: swallow exceptions raised by the callback
        """

        self._register_callback(
            self._ExitHandler,
            ignore_errors=ignore_errors,
            callback=callback,
            args=args,
            kwargs=kwargs,
        )

    def _register_callback(
        self,
        handler_type,
        callback: Callable,
        args: Optional[Tuple[Any, ...]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        ignore_errors: bool = False,
    ):
        """
        Binds the arguments to the callback and pushes it onto the exit stack,
        wrapped in the given handler type (_ExitHandler or _ErrorHandler).
        """

        if args or kwargs:
            # Either may be None here; unpacking None would raise TypeError.
            callback = partial(callback, *(args or ()), **(kwargs or {}))

        self._stack.push(handler_type(callback, ignore_errors=ignore_errors))

    def add(self, cm: ContextManager[T]) -> T:
        """
        Enters a context manager and adds it to the exit stack.

        Returns:
            cm.__enter__() result
        """

        return self._stack.enter_context(cm)

    def enable(self):
        """Re-enables running the registered callbacks on scope exit."""
        self.enabled = True

    def disable(self):
        """
        Disables running callbacks on scope exit. While disabled, exiting
        the scope leaves every registered callback and context manager
        on the stack untouched.
        """
        self.enabled = False

    def close(self):
        """Exits the scope as if no error happened (error callbacks do not run)."""
        self.__exit__(None, None, None)

    def __enter__(self) -> Scope:
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if not self.enabled:
            return

        # NOTE: the ExitStack's return value (whether an exception was
        # suppressed) is discarded, so an exception suppressed by an added
        # context manager still propagates out of this scope.
        self._stack.__exit__(exc_type, exc_value, exc_traceback)
        self._stack.pop_all()  # prevent issues on repetitive calls

    @classmethod
    def current(cls) -> Scope:
        """
        Returns the scope made current in this thread via as_current().

        Raises:
            AttributeError: if no scope is current in the calling thread.
        """

        # as_current() restores the previous value, which is None for the
        # outermost scope, so a missing attribute and None both mean "no scope".
        current = getattr(cls._thread_locals, "current", None)
        if current is None:
            raise AttributeError(
                "There is no current scope in this thread. "
                "Use Scope.as_current() or the @scoped decorator."
            )
        return current

    @contextmanager
    def as_current(self):
        """
        Makes this scope the current one for the calling thread inside the
        with-block and restores the previously current one afterwards.
        """

        previous = getattr(self._thread_locals, "current", None)
        self._thread_locals.current = self
        try:
            yield
        finally:
            self._thread_locals.current = previous
@optional_arg_decorator
def scoped(func, arg_name=None):
    """
    A function decorator, which allows to do actions with the current scope,
    such as registering error and exit callbacks and context managers.

    Each call of the decorated function runs inside a fresh Scope:

    - when arg_name is None, the new scope is made current for the duration
      of the call, so Scope.current() resolves to it;
    - otherwise, the scope is passed to the function as the keyword
      argument named arg_name, and is not made current.
    """

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        with Scope() as new_scope:
            if arg_name is not None:
                kwargs[arg_name] = new_scope
                return func(*args, **kwargs)

            with new_scope.as_current():
                return func(*args, **kwargs)

    return wrapped_func
# Shorthands for common cases
def on_error_do(callback, *args, ignore_errors=False, kwargs=None):
    """
    Shorthand for Scope.current().on_error_do(...): registers callback on the
    current scope, to be called only if that scope exits because of an error.
    """
    active_scope = Scope.current()
    return active_scope.on_error_do(
        callback, *args, kwargs=kwargs, ignore_errors=ignore_errors
    )
def on_exit_do(callback, *args, ignore_errors=False, kwargs=None):
    """
    Shorthand for Scope.current().on_exit_do(...): registers callback on the
    current scope, to be called whenever that scope exits.
    """
    active_scope = Scope.current()
    return active_scope.on_exit_do(
        callback, *args, kwargs=kwargs, ignore_errors=ignore_errors
    )
def scope_add(cm: ContextManager[T]) -> T:
    """
    Shorthand for Scope.current().add(cm): enters cm and ties its exit to
    the current scope.

    Returns:
        cm.__enter__() result
    """
    active_scope = Scope.current()
    return active_scope.add(cm)