Source code for datumaro.components.config

# Copyright (C) 2019-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import IO, Union

import yaml

from datumaro.components.errors import ImmutableObjectError


class Schema:
    """A mapping from configuration key names to schema entries.

    Entries are normally ``Schema.Item`` objects, but any mapping passed as
    ``items`` is stored as-is. An optional ``fallback`` (another ``Schema``
    or a mapping) supplies entries for keys that are not defined here;
    entries defined here take precedence over the fallback's.
    """

    class Item:
        """A schema entry: a default-value constructor plus an ``internal`` flag.

        Calling the item forwards all arguments to ``ctor`` and returns the
        result.
        """

        def __init__(self, ctor, internal=False):
            self.ctor = ctor
            self.internal = internal

        def __call__(self, *args, **kwargs):
            return self.ctor(*args, **kwargs)

    def __init__(self, items=None, fallback=None):
        """Create a schema from ``items`` (copied) and an optional ``fallback``."""
        self._items = {}
        if items is not None:
            self._items.update(items)
        self._fallback = fallback

    def _get_items(self, allow_fallback=True):
        """Return a new dict of entries, own entries overriding fallback ones."""
        all_items = {}
        if allow_fallback and self._fallback is not None:
            all_items.update(self._fallback)
        all_items.update(self._items)
        return all_items

    def items(self, allow_fallback=True):
        """Return ``(key, entry)`` pairs, optionally including the fallback."""
        return self._get_items(allow_fallback=allow_fallback).items()

    def keys(self, allow_fallback=True):
        """Return the known keys, optionally including the fallback's."""
        return self._get_items(allow_fallback=allow_fallback).keys()

    def values(self, allow_fallback=True):
        """Return the entries, optionally including the fallback's."""
        return self._get_items(allow_fallback=allow_fallback).values()

    def __contains__(self, key):
        return key in self.keys()

    def __len__(self):
        return len(self._get_items())

    def __iter__(self):
        return iter(self._get_items())

    def __getitem__(self, key):
        """Return the entry for ``key``; raise ``KeyError`` if it is unknown."""
        # A private sentinel distinguishes "missing" from any stored value.
        default = object()
        value = self.get(key, default=default)
        if value is default:
            raise KeyError('Key "%s" does not exist' % (key,))
        return value

    def get(self, key, default=None):
        """Return the entry for ``key``, consulting the fallback if needed.

        Own entries win over the fallback. ``default`` is returned only when
        neither this schema nor the fallback knows the key.
        """
        # Use a membership test rather than comparing the looked-up value
        # with ``default`` by identity: a stored entry that happens to be the
        # same object as ``default`` must still be treated as present.
        if key in self._items:
            return self._items[key]

        if self._fallback is not None:
            return self._fallback.get(key, default)

        return default
class SchemaBuilder:
    """Incrementally collects schema entries and turns them into a ``Schema``.

    ``add()`` returns the builder itself, so calls can be chained before the
    final ``build()``.
    """

    def __init__(self):
        self._items = {}

    def add(self, name, ctor=str, internal=False):
        """Register a new key.

        Args:
            name: key name; must not have been added before.
            ctor: callable producing the default value for the key.
            internal: flag stored on the created ``Schema.Item``.

        Returns:
            This builder.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        registry = self._items
        if name in registry:
            raise KeyError('Key "%s" already exists' % (name,))

        registry[name] = Schema.Item(ctor, internal=internal)
        return self

    def build(self):
        """Create a ``Schema`` from the entries collected so far."""
        return Schema(self._items)
class Config:
    """Dictionary-like configuration with optional schema validation.

    Values can be read and written both as items (``cfg["key"]``) and as
    attributes (``cfg.key``). When a schema is supplied, only keys known to
    the schema can be set, values are type-checked against an instance built
    by the schema entry, and keys that were never set are filled with the
    schema default the first time they are read.
    """

    def __init__(self, config=None, fallback=None, schema=None, mutable=True):
        """Build a configuration.

        Args:
            config: mapping (anything with ``items()``) of initial values.
            fallback: another config whose own values (``items(allow_fallback=False)``)
                are applied first and may be overridden by ``config``.
            schema: object describing allowed keys; ``None`` disables checks.
            mutable: whether ``set()``/``remove()`` are allowed after
                construction.
        """
        # Internal state is written straight into ``__dict__`` because
        # ``__setattr__`` is redirected to ``set()``.
        # schema should be established first
        self.__dict__["_schema"] = schema
        # Stay writable while the initial values are being applied.
        self.__dict__["_mutable"] = True
        self.__dict__["_config"] = {}

        if fallback is not None:
            for k, v in fallback.items(allow_fallback=False):
                self.set(k, v)

        if config is not None:
            self.update(config)

        self.__dict__["_mutable"] = mutable

    def _items(self, allow_fallback=True, allow_internal=True):
        """Return a new dict with the effective configuration values.

        With ``allow_fallback`` every schema key gets a freshly built default,
        which explicitly stored values then override. Without
        ``allow_internal`` keys whose schema entry is marked ``internal`` are
        dropped.
        """
        all_config = {}
        if allow_fallback and self._schema is not None:
            for key, item in self._schema.items():
                all_config[key] = item()
        all_config.update(self._config)

        if not allow_internal and self._schema is not None:
            for key, item in self._schema.items():
                if item.internal:
                    all_config.pop(key, None)
        return all_config

    def items(self, allow_fallback=True, allow_internal=True):
        """Return ``(key, value)`` pairs of the effective configuration."""
        return self._items(allow_fallback=allow_fallback, allow_internal=allow_internal).items()

    def keys(self, allow_fallback=True, allow_internal=True):
        """Return the keys of the effective configuration."""
        return self._items(allow_fallback=allow_fallback, allow_internal=allow_internal).keys()

    def values(self, allow_fallback=True, allow_internal=True):
        """Return the values of the effective configuration."""
        return self._items(allow_fallback=allow_fallback, allow_internal=allow_internal).values()

    def __contains__(self, key):
        return key in self.keys()

    def __len__(self):
        return len(self.items())

    def __iter__(self):
        return iter(self.keys())

    def __getitem__(self, key):
        # A private sentinel distinguishes "missing" from any stored value.
        default = object()
        value = self.get(key, default=default)
        if value is default:
            raise KeyError('Key "%s" does not exist' % (key,))
        return value

    def __setitem__(self, key, value):
        return self.set(key, value)

    def __getattr__(self, key):
        # Called only when regular attribute lookup fails. ``copy`` and
        # ``pickle`` create instances through ``__new__`` (no ``__init__``)
        # and then probe them with ``hasattr``; at that point ``_config``
        # does not exist yet, and going through ``get()`` would re-enter this
        # method endlessly. Report the attribute as missing instead.
        if "_config" not in self.__dict__:
            raise AttributeError(key)

        default = object()
        found = self.get(key, default=default)
        if found is default:
            raise AttributeError(key)
        return found

    def __setattr__(self, key, value):
        return self.set(key, value)

    def __eq__(self, other):
        """Compare key by key over the union of both key sets.

        Returns ``False`` when either side lacks a key or ``other`` does not
        offer the required mapping interface.
        """
        try:
            keys = set(self.keys()) | set(other.keys())
            for k in keys:
                my_v = self[k]
                other_v = other[k]
                if my_v != other_v:
                    return False
            return True
        except (KeyError, AttributeError):
            return False

    def __repr__(self):
        return repr(dict(self))

    def update(self, other):
        """Set every ``(key, value)`` pair from ``other.items()``."""
        for k, v in other.items():
            self.set(k, v)

    def remove(self, key):
        """Delete a stored value; missing keys are ignored.

        Raises:
            ImmutableObjectError: if the configuration is immutable.
        """
        if not self._mutable:
            raise ImmutableObjectError()

        self._config.pop(key, None)

    def get(self, key, default=None):
        """Return the value for ``key``.

        A stored value is returned as-is. Otherwise, if the schema knows the
        key, a default is built from the schema entry, remembered in the
        configuration (even when it is immutable) and returned. If neither
        applies, ``default`` is returned.
        """
        config = self._config
        # Membership test instead of an identity comparison with ``default``:
        # a stored value that is the very same object as ``default`` must not
        # be mistaken for "missing" and replaced by the schema default.
        if key in config:
            return config[key]

        schema = self._schema
        if schema is not None:
            missing = object()
            entry = schema.get(key, missing)
            if entry is not missing:
                # ignore mutability
                value = entry()
                config[key] = value
                return value

        return default

    def set(self, key, value):
        """Store ``value`` under ``key`` and return the stored value.

        With a schema, the key must be known to it. If the schema default is
        a ``Config`` and ``value`` is a ``dict`` or ``Config``, the value is
        merged into a fresh default instance, which is stored instead.
        Otherwise ``value`` must be an instance of the default's type.

        Raises:
            ImmutableObjectError: if the configuration is immutable.
            KeyError: if the schema does not know ``key``.
            ValueError: if ``value`` has an unexpected type.
        """
        if not self._mutable:
            raise ImmutableObjectError()

        if self._schema is not None:
            if key not in self._schema:
                raise KeyError("Can not set key '%s' - schema mismatch: " "unknown key" % (key,))

            schema_entry = self._schema[key]
            schema_entry_instance = schema_entry()
            if isinstance(value, (dict, Config)) and isinstance(schema_entry_instance, Config):
                schema_entry_instance.update(value)
                value = schema_entry_instance
            elif not isinstance(value, type(schema_entry_instance)):
                raise ValueError(
                    "Can not set key '%s' - schema mismatch:"
                    "unexpected value type %s, expected %s"
                    % (key, type(value), type(schema_entry_instance))
                )

        self._config[key] = value
        return value

    @classmethod
    def parse(cls, path: Union[str, IO], *args, **kwargs):
        """Load a configuration from YAML.

        Args:
            path: a file path (opened as UTF-8 text) or an open stream.
            *args, **kwargs: forwarded to the class constructor after the
                loaded data.
        """
        if isinstance(path, str):
            with open(path, "r", encoding="utf-8") as f:
                return cls(yaml.safe_load(f), *args, **kwargs)
        else:
            return cls(yaml.safe_load(path), *args, **kwargs)

    @staticmethod
    def yaml_representer(dumper, value):
        """Represent a config as its explicitly stored, non-internal values."""
        return dumper.represent_data(value._items(allow_internal=False, allow_fallback=False))

    def dump(self, path: Union[str, IO]):
        """Write the configuration as YAML.

        Args:
            path: a file path (written as UTF-8 text) or an open stream.
        """
        if isinstance(path, str):
            with open(path, "w", encoding="utf-8") as f:
                yaml.safe_dump(self, f)
        else:
            yaml.safe_dump(self, path)
# Teach the safe dumper to serialize Config objects through
# Config.yaml_representer.
yaml.add_multi_representer(
    Config,
    Config.yaml_representer,
    Dumper=yaml.SafeDumper,
)

# Have the safe dumper write tuples as plain lists.
yaml.add_multi_representer(
    tuple,
    lambda dumper, value: dumper.represent_data(list(value)),
    Dumper=yaml.SafeDumper,
)
class DictConfig(Config):
    """A schema-less ``Config`` whose values are all checked against one factory.

    ``default`` is a callable invoked with each value being set; its result
    determines which value types are accepted.
    """

    def __init__(self, default=None):
        super().__init__()
        # Written via ``__dict__`` because attribute assignment is routed to
        # ``set()``.
        self.__dict__["_default"] = default

    def set(self, key, value):
        """Validate ``value`` with the default factory, then store it.

        If the factory result is a ``Config`` and ``value`` is a ``dict`` or
        ``Config``, the factory result is stored in place of ``value``.
        Otherwise ``value`` must be an instance of the result's type.

        Raises:
            ValueError: if ``value`` has an unexpected type.
        """
        factory = self._default
        if factory is not None:
            sample = factory(value)
            is_nested_config = isinstance(value, (dict, Config)) and isinstance(sample, Config)
            if is_nested_config:
                value = sample
            elif not isinstance(value, type(sample)):
                expected_type = type(sample)
                raise ValueError(
                    "Can not set key '%s' - schema mismatch:"
                    "unexpected value type %s, expected %s" % (key, type(value), expected_type)
                )

        return super().set(key, value)