diff --git a/chanfig/config.py b/chanfig/config.py index a885bb62..4efc5042 100755 --- a/chanfig/config.py +++ b/chanfig/config.py @@ -103,7 +103,7 @@ def frozen_check(func: Callable): @wraps(func) def decorator(self, *args, **kwargs): if self.getattr("frozen", False): - raise ValueError("Attempting to alter a frozen config. Run config.defrost() to defrost first") + raise ValueError("Attempting to alter a frozen config. Run config.defrost() to defrost first.") return func(self, *args, **kwargs) return decorator @@ -123,6 +123,46 @@ class Config(NestedDict): accessing anything that does not exist will create a new empty Config sub-attribute. It is recommended to call `config.freeze()` or `config.to(NestedDict)` to avoid this behavior. + + Attributes + ---------- + parser: ConfigParser = ConfigParser() + Parser for command line arguments. + frozen: bool = False + If `True`, the config is frozen and cannot be altered. + + Examples + -------- + ```python + >>> c = Config(**{"f.n": "chang"}) + >>> c.i.d = 1013 + >>> c.i.d + 1013 + >>> c.d.i + Config() + >>> c.freeze() + Config( + (f): Config( + (n): 'chang' + ) + (i): Config( + (d): 1013 + ) + (d): Config( + (i): Config() + ) + ) + >>> c.d.i = 1013 + Traceback (most recent call last): + ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first. + >>> c.d.e + Traceback (most recent call last): + KeyError: 'Config does not contain e' + >>> with c.unlocked(): + ... del c.d + >>> c.dict() + {'f': {'n': 'chang'}, 'i': {'d': 1013}} + """ parser: ConfigParser @@ -131,7 +171,9 @@ class Config(NestedDict): def __init__(self, *args, **kwargs): if not self.hasattr("default_mapping"): self.setattr("default_mapping", Config) - super().__init__(*args, default_factory=Config, **kwargs) + if "default_factory" not in kwargs: + kwargs["default_factory"] = Config + super().__init__(*args, **kwargs) self.setattr("parser", ConfigParser()) def get(self, name: str, default: Optional[Any] = None) -> Any: @@ -234,7 +276,7 @@ def set( {'i': {'d': 1013}} >>> c['i.d'] = 1013 Traceback (most recent call last): - ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first + ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first. >>> c.defrost().dict() {'i': {'d': 1013}} >>> c['i.d'] = 1013 @@ -316,7 +358,7 @@ def pop(self, name: str, default: Optional[Any] = None) -> Any: {'i': {}} >>> c['i.d'] = 1013 Traceback (most recent call last): - ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first + ValueError: Attempting to alter a frozen config. Run config.defrost() to defrost first. >>> c.defrost().dict() {'i': {}} >>> c['i.d'] = 1013 diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index 66455c3d..4b60bb33 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -12,6 +12,7 @@ from os import PathLike from os.path import splitext from typing import IO, Any, Callable, Iterable, Optional, Union +from warnings import warn from yaml import dump as yaml_dump from yaml import load as yaml_load @@ -51,6 +52,19 @@ class FlatDict(OrderedDict): `FlatDict` works best with `Variable` objects. + Note that since `FlatDict` supports attribute-style access to keys. + Therefore, all internal attributes should be set and get through `FlatDict.setattr` and `FlatDict.getattr`. + + `__class__`, `__dict__`, and `getattr` are reserved and cannot be override in any manner. + Although it is possible to override other internal methods, it is not recommended. + + Attributes + ---------- + indent: int + Indentation level in printing and dumping to json or yaml. + default_factory: Optional[Callable] + Default factory for defaultdict behavior. + Examples -------- ```python @@ -93,7 +107,7 @@ def __init__(self, *args, default_factory: Optional[Callable] = None, **kwargs) self.setattr("default_factory", default_factory) else: raise TypeError( - f"default_factory={default_factory} should be of type Callable, but got {type(default_factory)}" + f"default_factory={default_factory} should be of type Callable, but got {type(default_factory)}." ) self._init(*args, **kwargs) @@ -114,6 +128,11 @@ def _init(self, *args, **kwargs) -> None: for key, value in kwargs.items(): self.set(key, value) + def __getattribute__(self, name): + if name not in ("__class__", "__dict__", "getattr") and name in self: + return self[name] + return super().__getattribute__(name) + def get(self, name: str, default: Optional[Any] = None) -> Any: r""" Get value from FlatDict. @@ -155,7 +174,7 @@ def get(self, name: str, default: Optional[Any] = None) -> Any: 2 >>> d.get('f') Traceback (most recent call last): - KeyError: 'FlatDict does not contain f' + KeyError: 'FlatDict does not contain f.' ``` """ @@ -233,11 +252,11 @@ def delete(self, name: str) -> None: >>> d.delete('d') >>> d.d Traceback (most recent call last): - KeyError: 'FlatDict does not contain d' + KeyError: 'FlatDict does not contain d.' >>> del d.n >>> d.n Traceback (most recent call last): - KeyError: 'FlatDict does not contain n' + KeyError: 'FlatDict does not contain n.' >>> del d.f Traceback (most recent call last): KeyError: 'f' @@ -276,7 +295,7 @@ def getattr(self, name: str, default: Optional[Any] = None) -> Any: 2 >>> d.getattr('a') Traceback (most recent call last): - AttributeError: FlatDict has no attribute a + AttributeError: FlatDict has no attribute a. ``` """ @@ -290,7 +309,7 @@ def getattr(self, name: str, default: Optional[Any] = None) -> Any: except AttributeError: if default is not None: return default - raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") from None + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}.") from None def setattr(self, name: str, value: Any) -> None: r""" @@ -303,6 +322,11 @@ def setattr(self, name: str, value: Any) -> None: name: str value: Any + Warns + ------ + RuntimeWarning + If name already exists in FlatDict. + Examples -------- ```python @@ -310,10 +334,24 @@ def setattr(self, name: str, value: Any) -> None: >>> d.setattr('attr', 'value') >>> d.getattr('attr') 'value' + >>> d.set('d', 1013) + >>> d.setattr('d', 1031) # Trigger RuntimeWarning: d already exists in FlatDict. + >>> d.get('d') + 1013 + >>> d.d + 1013 + >>> d.getattr('d') + 1031 ``` """ + if name in self: + warn( + f"{name} already exists in {self.__class__.__name__}.\n" + "Users must call `{self.__class__.__name__}.getattr()` to retrieve conflicting attribute value.", + RuntimeWarning, + ) self.__dict__[name] = value def hasattr(self, name: str) -> bool: @@ -365,7 +403,7 @@ def delattr(self, name: str) -> None: >>> d.delattr('name') >>> d.getattr('name') Traceback (most recent call last): - AttributeError: FlatDict has no attribute name + AttributeError: FlatDict has no attribute name. ``` """ @@ -373,37 +411,10 @@ def delattr(self, name: str) -> None: del self.__dict__[name] def __missing__(self, name: str, default: Optional[Any] = None) -> Any: - r""" - Allow dict to have default value if it doesn't exist. - - Parameters - ---------- - name: str - default: Optional[Any] = None - - Returns - ------- - value: Any - If name does not exist, return `default`. - - Examples - -------- - ```python - >>> d = FlatDict(default_factory=list) - >>> d.n - [] - >>> d.get('d', 1013) - 1013 - >>> d.__missing__('d', 1013) - 1013 - - ``` - """ - if default is None: # default_factory might not in __dict__ and cannot be replaced with if self.getattr("default_factory") if "default_factory" not in self.__dict__: - raise KeyError(f"{self.__class__.__name__} does not contain {name}") + raise KeyError(f"{self.__class__.__name__} does not contain {name}.") default_factory = self.getattr("default_factory") default = default_factory() if isinstance(default, FlatDict): @@ -488,7 +499,7 @@ def difference(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict: {'d': 4} >>> d.difference(1) Traceback (most recent call last): - TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got + TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got . ``` """ @@ -498,7 +509,7 @@ def difference(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict: if isinstance(other, (Mapping,)): other = other.items() if not isinstance(other, Iterable): - raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}") + raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}.") return self.empty_like( **{key: value for key, value in other if key not in self or self[key] != value} # type: ignore @@ -534,7 +545,7 @@ def intersection(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict: {'c': 3} >>> d.intersection(1) Traceback (most recent call last): - TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got + TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got . ``` """ @@ -544,7 +555,7 @@ def intersection(self, other: Union[Mapping, Iterable, PathStr]) -> FlatDict: if isinstance(other, (Mapping,)): other = other.items() if not isinstance(other, Iterable): - raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}") + raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}.") return self.empty_like( **{key: value for key, value in other if key in self and self[key] == value} # type: ignore ) @@ -715,7 +726,7 @@ def to(self, cls: Union[str, TorchDevice, TorchDtype]) -> FlatDict: self[k] = v.to(cls) return self - raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}") + raise TypeError(f"to() only support torch.dtype and torch.device, but got {cls}.") def cpu(self) -> FlatDict: r""" @@ -1014,7 +1025,7 @@ def dump(self, file: File, method: Optional[str] = None, *args, **kwargs) -> Non if method is None: if isinstance(file, IO): - raise ValueError("method must be specified when dumping to file-like object") + raise ValueError("method must be specified when dumping to file-like object.") method = splitext(file)[-1][1:] # type: ignore extension = method.lower() # type: ignore if extension in YAML: @@ -1044,14 +1055,14 @@ def load(cls, file: File, method: Optional[str] = None, *args, **kwargs) -> Flat if method is None: if isinstance(file, IO): - raise ValueError("method must be specified when loading from file-like object") + raise ValueError("method must be specified when loading from file-like object.") method = splitext(file)[-1][1:] # type: ignore extension = method.lower() # type: ignore if extension in JSON: return cls.from_json(file, *args, **kwargs) if extension in YAML: return cls.from_yaml(file, *args, **kwargs) - raise FileError("file {file} should be in {JSON} or {YAML}, but got {extension}") + raise FileError("file {file} should be in {JSON} or {YAML}, but got {extension}.") @staticmethod @contextmanager @@ -1069,28 +1080,13 @@ def open(file: File, *args, **kwargs): elif isinstance(file, (IO,)): yield file else: - raise TypeError( - f"file={file} should be of type (str, os.PathLike) or (io.IOBase), but got {type(file)}" # type: ignore - ) + raise TypeError(f"file={file!r} should be of type (str, os.PathLike) or (io.IOBase), but got {type(file)}.") @staticmethod def extra_repr() -> str: # pylint: disable=C0116 return "" def __repr__(self) -> str: - r""" - Representation of FlatDict. - - Examples - -------- - ```python - >>> d = FlatDict(a=1, b=2, c=3) - >>> repr(d) - 'FlatDict(\n (a): 1\n (b): 2\n (c): 3\n)' - - ``` - """ - extra_lines = [] extra_repr = self.extra_repr() # empty string will be split into list [''] diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index 9b5729e9..299e8879 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -23,6 +23,18 @@ class NestedDict(FlatDict): `NestedDict` also has `all_keys`, `all_values`, `all_items` methods to get all keys, values, items respectively in the nested structure. + When `convert_mapping` specified, all new values with a type of `Mapping` will be converted to `default_mapping`. + Note that `convert_mapping` is automatically applied to arguments at initialisation. + + Attributes + ---------- + default_mapping: Callable = NestedDict + Default mapping when performing `convert_mapping`. + convert_mapping: bool = False + If `True`, all new values with a type of `Mapping` will be converted to `default_mapping`. + delimiter: str = "." + Delimiter for nested structure. + Examples -------- ```python @@ -98,7 +110,7 @@ def get(self, name: str, default: Optional[Any] = None) -> Any: >>> d = NestedDict() >>> d.f Traceback (most recent call last): - KeyError: 'NestedDict does not contain f' + KeyError: 'NestedDict does not contain f.' ``` """ @@ -207,11 +219,11 @@ def delete(self, name: str) -> None: False >>> d.i.d Traceback (most recent call last): - KeyError: 'NestedDict does not contain d' + KeyError: 'NestedDict does not contain d.' >>> del d.f.n >>> d.f.n Traceback (most recent call last): - KeyError: 'NestedDict does not contain n' + KeyError: 'NestedDict does not contain n.' >>> del d.c Traceback (most recent call last): KeyError: 'c' @@ -422,7 +434,7 @@ def difference( # pylint: disable=W0221 {'d': 4} >>> d.difference(1) Traceback (most recent call last): - TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got + TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got . ``` """ @@ -434,7 +446,7 @@ def difference( # pylint: disable=W0221 if isinstance(other, (Mapping,)): other = self.empty_like(**other).items() if not isinstance(other, Iterable): - raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}") + raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}.") @wraps(self.difference) def difference(this: NestedDict, that: Iterable) -> Mapping: @@ -481,7 +493,7 @@ def intersection( # pylint: disable=W0221 {'a': 1} >>> d.intersection(1) Traceback (most recent call last): - TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got + TypeError: other=1 should be of type Mapping, Iterable or PathStr, but got . ``` """ @@ -491,7 +503,7 @@ def intersection( # pylint: disable=W0221 if isinstance(other, (Mapping,)): other = self.empty_like(**other).items() if not isinstance(other, Iterable): - raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}") + raise TypeError(f"other={other} should be of type Mapping, Iterable or PathStr, but got {type(other)}.") @wraps(self.intersection) def intersection(this: NestedDict, that: Iterable) -> Mapping: @@ -570,7 +582,7 @@ class DefaultDict(NestedDict): In addition, if you access a key that does not exist, the value will be set to `default_factory()`. """ - def __init__(self, *args, default_factory: Optional[Callable] = None, **kwargs): - if default_factory is None: - default_factory = NestedDict - super().__init__(*args, default_factory=default_factory, **kwargs) + def __init__(self, *args, **kwargs): + if "default_factory" not in kwargs: + kwargs["default_factory"] = NestedDict + super().__init__(*args, **kwargs) diff --git a/tests/test_config.py b/tests/test_config.py index d4ab6e31..41246558 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,6 +5,7 @@ class TestConfig(Config): __test__ = False def __init__(self, *args, **kwargs): + super().__init__() num_classes = Variable(10) self.name = "CHANfiG" self.seed = Variable(1013)