Skip to content

Commit

Permalink
fix setdefault
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Sep 21, 2023
1 parent 74cabb3 commit 8c2ad7f
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,84 @@ def pop(self, name: Any, default: Any = Null) -> Any:
raise KeyError(name)
return super().pop(name)

def setdefault(
self,
name: Any,
value: Any,
convert_mapping: bool | None = None,
) -> Any:
r"""
Set default value for `NestedDict`.
Args:
name:
value:
convert_mapping: Whether to convert `Mapping` to `NestedDict`.
Defaults to self.convert_mapping.
Returns:
value: If `NestedDict` does not contain `name`, return `value`.
Examples:
>>> d = NestedDict({"i.d": 1013, "f.n": "chang", "n.a.b.c": 1})
>>> d.setdefault("d.i", 1031)
1031
>>> d.setdefault("i.d", "chang")
1013
>>> d.setdefault("f.n", 1013)
'chang'
>>> d.setdefault("n.a.b.d", 2)
2
"""
# pylint: disable=W0642

full_name = name
delimiter = self.getattr("delimiter", ".")
if convert_mapping is None:
convert_mapping = self.getattr("convert_mapping", False)
default_factory = self.getattr("default_factory", self.empty)
try:
while isinstance(name, str) and delimiter in name:
name, rest = name.split(delimiter, 1)
if name in dir(self) and isinstance(getattr(self.__class__, name), (property, cached_property)):
self, name = getattr(self, name), rest
elif name not in self and isinstance(self, Mapping):
default = (
self.__missing__(name, default_factory()) if hasattr(self, "__missing__") else default_factory()
)
self, name = default, rest
else:
self, name = self[name], rest
if isinstance(self, NestedDict):
default_factory = self.getattr("default_factory", self.empty)
except (AttributeError, TypeError):
raise KeyError(name) from None

if isinstance(self, NestedDict) and name in self:
return super().get(name)
elif isinstance(self, Mapping) and name in self:
dict.__getitem__(self, name)

if (
convert_mapping
and isinstance(value, Mapping)
and not isinstance(value, default_factory if isinstance(default_factory, type) else type(self))
and not isinstance(value, Variable)
):
try:
value = default_factory(**value)
except TypeError:
value = default_factory(value)
if isinstance(self, NestedDict):
super().set(name, value)
elif isinstance(self, Mapping):
dict.__setitem__(self, name, value)
else:
raise ValueError(
f"Cannot set `{full_name}` to `{value}`, as `{delimiter.join(full_name.split(delimiter)[:-1])}={self}`."
)
return value

def validate(self) -> None:
r"""
Validate `NestedDict`.
Expand Down

0 comments on commit 8c2ad7f

Please sign in to comment.