Skip to content

Commit

Permalink
support fallback in NestedDict.get
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Sep 14, 2023
1 parent 03211bf commit 53fd169
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 9 deletions.
4 changes: 2 additions & 2 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def unlocked(self):
if was_frozen:
self.freeze()

def get(self, name: Any, default: Any = None) -> Any:
def get(self, name: Any, default: Any = None, fallback: bool | None = None) -> Any:
r"""
Get value from `Config`.
Expand Down Expand Up @@ -476,7 +476,7 @@ def get(self, name: Any, default: Any = None) -> Any:
if not self.hasattr("default_factory"): # did not call super().__init__() in sub-class
self.setattr("default_factory", Config)
if name in self or not self.getattr("frozen", False):
return super().get(name, default)
return super().get(name, default, fallback)
raise KeyError(name)

@frozen_check
Expand Down
2 changes: 1 addition & 1 deletion chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def merge(self, *args: Any, overwrite: bool = True, **kwargs: Any) -> FlatDict:
def _merge(this: FlatDict, that: Iterable, overwrite: bool = True) -> Mapping:
if not that:
return this
elif isinstance(that, Mapping):
if isinstance(that, Mapping):
that = that.items()
for key, value in that:
if key in this and isinstance(this[key], Mapping):
Expand Down
28 changes: 24 additions & 4 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,20 @@ class NestedDict(DefaultDict): # type: ignore # pylint: disable=E1136

convert_mapping: bool = False
delimiter: str = "."
fallback: bool = False

def __init__(
self, *args: Any, default_factory: Callable | None = None, convert_mapping: bool = False, **kwargs: Any
self,
*args: Any,
default_factory: Callable | None = None,
convert_mapping: bool = False,
fallback: bool = False,
**kwargs: Any,
) -> None:
super().__init__(default_factory)
self.merge(*args, **kwargs)
self.setattr("convert_mapping", convert_mapping)
self.setattr("fallback", fallback)

def all_keys(self) -> Generator:
r"""
Expand Down Expand Up @@ -289,7 +296,7 @@ def apply_(self, func: Callable, *args: Any, **kwargs: Any) -> NestedDict:
apply_(self, func, *args, **kwargs)
return self

def get(self, name: Any, default: Any = None) -> Any:
def get(self, name: Any, default: Any = None, fallback: bool | None = None) -> Any:
r"""
Get value from `NestedDict`.
Expand Down Expand Up @@ -320,6 +327,7 @@ def get(self, name: Any, default: Any = None) -> Any:
1013
>>> d.get('f', 2)
2
>>> d.get('a.b', None)
>>> d.f
NestedDict(<class 'chanfig.nested_dict.NestedDict'>, )
>>> del d.f
Expand All @@ -340,12 +348,24 @@ def get(self, name: Any, default: Any = None) -> Any:
"""

delimiter = self.getattr("delimiter", ".")
if fallback is None:
fallback = self.getattr("fallback", False)
fallback_name = name.split(delimiter)[-1] if isinstance(name, str) else name
fallback_values = []
try:
while isinstance(name, str) and delimiter in name:
if fallback and fallback_name in self:
fallback_values.append(self.get(fallback_name))
name, rest = name.split(delimiter, 1)
self, name = self[name], rest # pylint: disable=W0642
except (AttributeError, TypeError):
except (KeyError, AttributeError, TypeError):
if fallback and fallback_values:
return fallback_values[-1]
if default is not Null:
return default
raise KeyError(name) from None
if (fallback and fallback_values) and (not isinstance(self, Iterable) or name not in self):
return fallback_values[-1]
# if value is a python dict
if not isinstance(self, NestedDict):
if name not in self and default is not Null:
Expand Down Expand Up @@ -405,9 +425,9 @@ def set( # pylint: disable=W0221
# pylint: disable=W0642

full_name = name
delimiter = self.getattr("delimiter", ".")
if convert_mapping is None:
convert_mapping = self.getattr("convert_mapping", False)
delimiter = self.getattr("delimiter", ".")
default_factory = self.getattr("default_factory", self.empty)
try:
while isinstance(name, str) and delimiter in name:
Expand Down
4 changes: 2 additions & 2 deletions chanfig/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class Registry(NestedDict): # type: ignore

override: bool = False

def __init__(self, override: bool = False):
super().__init__()
def __init__(self, override: bool = False, fallback: bool = False):
super().__init__(fallback=fallback)
self.setattr("override", override)

def register(self, component: Any = None, name: Any | None = None) -> Callable:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def test_dropnull(self):
}
assert d.dropnull().dict() == {"h": {"j": 1}}

def test_fallback(self):
d = NestedDict({"n.d": 0.5, "n.a.d": 0.1, "n.b.l": 6})
assert d.get("n.a.d", fallback=True) == 0.1
assert d.get("n.b.d", fallback=True) == 0.5


class ConfigDict(NestedDict):
def __init__(self):
Expand Down

0 comments on commit 53fd169

Please sign in to comment.