From 8c2ad7f19ed45190b01086889c97d8545145a1c0 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 21 Sep 2023 10:32:20 +0000 Subject: [PATCH] fix setdefault Signed-off-by: Zhiyuan Chen --- chanfig/nested_dict.py | 78 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index c0459fcc..c613feba 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -557,6 +557,84 @@ def pop(self, name: Any, default: Any = Null) -> Any: raise KeyError(name) return super().pop(name) + def setdefault( + self, + name: Any, + value: Any, + convert_mapping: bool | None = None, + ) -> Any: + r""" + Set default value for `NestedDict`. + + Args: + name: + value: + convert_mapping: Whether to convert `Mapping` to `NestedDict`. + Defaults to self.convert_mapping. + + Returns: + value: If `NestedDict` does not contain `name`, return `value`. + + Examples: + >>> d = NestedDict({"i.d": 1013, "f.n": "chang", "n.a.b.c": 1}) + >>> d.setdefault("d.i", 1031) + 1031 + >>> d.setdefault("i.d", "chang") + 1013 + >>> d.setdefault("f.n", 1013) + 'chang' + >>> d.setdefault("n.a.b.d", 2) + 2 + """ + # pylint: disable=W0642 + + full_name = name + delimiter = self.getattr("delimiter", ".") + if convert_mapping is None: + convert_mapping = self.getattr("convert_mapping", False) + default_factory = self.getattr("default_factory", self.empty) + try: + while isinstance(name, str) and delimiter in name: + name, rest = name.split(delimiter, 1) + if name in dir(self) and isinstance(getattr(self.__class__, name), (property, cached_property)): + self, name = getattr(self, name), rest + elif name not in self and isinstance(self, Mapping): + default = ( + self.__missing__(name, default_factory()) if hasattr(self, "__missing__") else default_factory() + ) + self, name = default, rest + else: + self, name = self[name], rest + if isinstance(self, NestedDict): + default_factory = self.getattr("default_factory", self.empty) + except (AttributeError, TypeError): + raise KeyError(name) from None + + if isinstance(self, NestedDict) and name in self: + return super().get(name) + elif isinstance(self, Mapping) and name in self: + dict.__getitem__(self, name) + + if ( + convert_mapping + and isinstance(value, Mapping) + and not isinstance(value, default_factory if isinstance(default_factory, type) else type(self)) + and not isinstance(value, Variable) + ): + try: + value = default_factory(**value) + except TypeError: + value = default_factory(value) + if isinstance(self, NestedDict): + super().set(name, value) + elif isinstance(self, Mapping): + dict.__setitem__(self, name, value) + else: + raise ValueError( + f"Cannot set `{full_name}` to `{value}`, as `{delimiter.join(full_name.split(delimiter)[:-1])}={self}`." + ) + return value + def validate(self) -> None: r""" Validate `NestedDict`.