From 8749d674d9352ac17ececd728165a6161959b36d Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 29 Jun 2023 12:16:35 +0800 Subject: [PATCH] fix save to/load from IOBase --- chanfig/flat_dict.py | 6 +++++- chanfig/nested_dict.py | 5 ++++- tests/test_config.py | 6 ++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index d3ec76ff..55d94867 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -21,7 +21,7 @@ from contextlib import contextmanager from copy import copy, deepcopy from functools import wraps -from io import IOBase +from io import BytesIO, IOBase, StringIO from json import dumps as json_dumps from json import loads as json_loads from os import PathLike @@ -895,6 +895,8 @@ def from_json(cls, file: File, *args, **kwargs) -> FlatDict: """ with cls.open(file) as fp: # pylint: disable=C0103 + if isinstance(fp, (StringIO, BytesIO)): + return cls.from_jsons(fp.getvalue(), *args, **kwargs) # type: ignore return cls.from_jsons(fp.read(), *args, **kwargs) def jsons(self, *args, **kwargs) -> str: @@ -965,6 +967,8 @@ def from_yaml(cls, file: File, *args, **kwargs) -> FlatDict: """ with cls.open(file) as fp: # pylint: disable=C0103 + if isinstance(fp, (StringIO, BytesIO)): + return cls.from_yamls(fp.getvalue(), *args, **kwargs) # type: ignore return cls.from_yamls(fp.read(), *args, **kwargs) def yamls(self, *args, **kwargs) -> str: diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index b317b4df..25262395 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -418,6 +418,7 @@ def set( # pylint: disable=W0221 default_factory = self.getattr("default_factory", self.empty_like) except (AttributeError, TypeError): raise KeyError(name) from None + if convert_mapping and isinstance(value, Mapping): value = default_factory(value) if isinstance(self, NestedDict): @@ -577,7 +578,9 @@ def merge(self, *args, **kwargs) -> NestedDict: @wraps(self.merge) def merge(this: NestedDict, that: Iterable) -> Mapping: - if isinstance(that, Mapping): + if isinstance(that, NestedDict): + that = that.all_items() + elif isinstance(that, Mapping): that = that.items() for key, value in that: if key in this and isinstance(this[key], Mapping): diff --git a/tests/test_config.py b/tests/test_config.py index 3688e94e..1296523a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,6 @@ from copy import copy, deepcopy from functools import partial +from io import StringIO from chanfig import Config, Variable @@ -79,8 +80,9 @@ def test_fstring(self): def test_load(self): self.config.name = "Test" self.config.datasets.a.num_classes = 12 - self.config.dump("tests/test_config.json") - self.config = self.config.load("tests/test_config.json") + buffer = StringIO() + self.config.dump(buffer, method="json") + assert self.config == Config.load(buffer, method="json") assert self.config.name == "Test" assert self.config.network.name == "ResNet18" assert self.config.network.num_classes == 12