From 53fd169c9bae2b7f58a9463d6b3d4af54a76aae4 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Mon, 11 Sep 2023 09:13:50 +0000 Subject: [PATCH] support fallback in NestedDict.get Signed-off-by: Zhiyuan Chen --- chanfig/config.py | 4 ++-- chanfig/flat_dict.py | 2 +- chanfig/nested_dict.py | 28 ++++++++++++++++++++++++---- chanfig/registry.py | 4 ++-- tests/test_nested_dict.py | 5 +++++ 5 files changed, 34 insertions(+), 9 deletions(-) diff --git a/chanfig/config.py b/chanfig/config.py index 4fcbee21..dcb8dfb8 100755 --- a/chanfig/config.py +++ b/chanfig/config.py @@ -428,7 +428,7 @@ def unlocked(self): if was_frozen: self.freeze() - def get(self, name: Any, default: Any = None) -> Any: + def get(self, name: Any, default: Any = None, fallback: bool | None = None) -> Any: r""" Get value from `Config`. @@ -476,7 +476,7 @@ def get(self, name: Any, default: Any = None) -> Any: if not self.hasattr("default_factory"): # did not call super().__init__() in sub-class self.setattr("default_factory", Config) if name in self or not self.getattr("frozen", False): - return super().get(name, default) + return super().get(name, default, fallback) raise KeyError(name) @frozen_check diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index df826b5b..58929f88 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -716,7 +716,7 @@ def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> FlatDict: def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping: if not that: return this - elif isinstance(that, Mapping): + if isinstance(that, Mapping): that = that.items() for key, value in that: if key in this and isinstance(this[key], Mapping): diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index 02a48c0b..66bceca1 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -149,13 +149,20 @@ class NestedDict(DefaultDict): # type: ignore # pylint: disable=E1136 convert_mapping: bool = False delimiter: str = "." + fallback: bool = False def __init__( - self, *args: Any, default_factory: Callable | None = None, convert_mapping: bool = False, **kwargs: Any + self, + *args: Any, + default_factory: Callable | None = None, + convert_mapping: bool = False, + fallback: bool = False, + **kwargs: Any, ) -> None: super().__init__(default_factory) self.merge(*args, **kwargs) self.setattr("convert_mapping", convert_mapping) + self.setattr("fallback", fallback) def all_keys(self) -> Generator: r""" @@ -289,7 +296,7 @@ def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict: apply_(self, func, *args, **kwargs) return self - def get(self, name: Any, default: Any = None) -> Any: + def get(self, name: Any, default: Any = None, fallback: bool | None = None) -> Any: r""" Get value from `NestedDict`. @@ -320,6 +327,7 @@ def get(self, name: Any, default: Any = None) -> Any: 1013 >>> d.get('f', 2) 2 + >>> d.get('a.b', None) >>> d.f NestedDict(, ) >>> del d.f @@ -340,12 +348,24 @@ def get(self, name: Any, default: Any = None) -> Any: """ delimiter = self.getattr("delimiter", ".") + if fallback is None: + fallback = self.getattr("fallback", False) + fallback_name = name.split(delimiter)[-1] if isinstance(name, str) else name + fallback_values = [] try: while isinstance(name, str) and delimiter in name: + if fallback and fallback_name in self: + fallback_values.append(self.get(fallback_name)) name, rest = name.split(delimiter, 1) self, name = self[name], rest # pylint: disable=W0642 - except (AttributeError, TypeError): + except (KeyError, AttributeError, TypeError): + if fallback and fallback_values: + return fallback_values[-1] + if default is not Null: + return default raise KeyError(name) from None + if (fallback and fallback_values) and (not isinstance(self, Iterable) or name not in self): + return fallback_values[-1] # if value is a python dict if not isinstance(self, NestedDict): if name not in self and default is not Null: @@ -405,9 +425,9 @@ def set( # pylint: disable=W0221 # pylint: disable=W0642 full_name = name + delimiter = self.getattr("delimiter", ".") if convert_mapping is None: convert_mapping = self.getattr("convert_mapping", False) - delimiter = self.getattr("delimiter", ".") default_factory = self.getattr("default_factory", self.empty) try: while isinstance(name, str) and delimiter in name: diff --git a/chanfig/registry.py b/chanfig/registry.py index 406d726b..e912e331 100644 --- a/chanfig/registry.py +++ b/chanfig/registry.py @@ -88,8 +88,8 @@ class Registry(NestedDict): # type: ignore override: bool = False - def __init__(self, override: bool = False): - super().__init__() + def __init__(self, override: bool = False, fallback: bool = False): + super().__init__(fallback=fallback) self.setattr("override", override) def register(self, component: Any = None, name: Any | None = None) -> Callable: diff --git a/tests/test_nested_dict.py b/tests/test_nested_dict.py index 53c8047d..cf58b988 100644 --- a/tests/test_nested_dict.py +++ b/tests/test_nested_dict.py @@ -86,6 +86,11 @@ def test_dropnull(self): } assert d.dropnull().dict() == {"h": {"j": 1}} + def test_fallback(self): + d = NestedDict({"n.d": 0.5, "n.a.d": 0.1, "n.b.l": 6}) + assert d.get("n.a.d", fallback=True) == 0.1 + assert d.get("n.b.d", fallback=True) == 0.5 + class ConfigDict(NestedDict): def __init__(self):