Skip to content

Commit

Permalink
fix save to/load from IOBase
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed Jun 29, 2023
1 parent 525d5b9 commit 8749d67
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
6 changes: 5 additions & 1 deletion chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import wraps
from io import IOBase
from io import BytesIO, IOBase, StringIO
from json import dumps as json_dumps
from json import loads as json_loads
from os import PathLike
Expand Down Expand Up @@ -895,6 +895,8 @@ def from_json(cls, file: File, *args, **kwargs) -> FlatDict:
"""

with cls.open(file) as fp: # pylint: disable=C0103
if isinstance(fp, (StringIO, BytesIO)):
return cls.from_jsons(fp.getvalue(), *args, **kwargs) # type: ignore
return cls.from_jsons(fp.read(), *args, **kwargs)

def jsons(self, *args, **kwargs) -> str:
Expand Down Expand Up @@ -965,6 +967,8 @@ def from_yaml(cls, file: File, *args, **kwargs) -> FlatDict:
"""

with cls.open(file) as fp: # pylint: disable=C0103
if isinstance(fp, (StringIO, BytesIO)):
return cls.from_yamls(fp.getvalue(), *args, **kwargs) # type: ignore
return cls.from_yamls(fp.read(), *args, **kwargs)

def yamls(self, *args, **kwargs) -> str:
Expand Down
5 changes: 4 additions & 1 deletion chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def set( # pylint: disable=W0221
default_factory = self.getattr("default_factory", self.empty_like)
except (AttributeError, TypeError):
raise KeyError(name) from None

if convert_mapping and isinstance(value, Mapping):
value = default_factory(value)
if isinstance(self, NestedDict):
Expand Down Expand Up @@ -577,7 +578,9 @@ def merge(self, *args, **kwargs) -> NestedDict:

@wraps(self.merge)
def merge(this: NestedDict, that: Iterable) -> Mapping:
if isinstance(that, Mapping):
if isinstance(that, NestedDict):
that = that.all_items()
elif isinstance(that, Mapping):
that = that.items()
for key, value in that:
if key in this and isinstance(this[key], Mapping):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import copy, deepcopy
from functools import partial
from io import StringIO

from chanfig import Config, Variable

Expand Down Expand Up @@ -79,8 +80,9 @@ def test_fstring(self):
def test_load(self):
self.config.name = "Test"
self.config.datasets.a.num_classes = 12
self.config.dump("tests/test_config.json")
self.config = self.config.load("tests/test_config.json")
buffer = StringIO()
self.config.dump(buffer, method="json")
assert self.config == Config.load(buffer, method="json")
assert self.config.name == "Test"
assert self.config.network.name == "ResNet18"
assert self.config.network.num_classes == 12
Expand Down

0 comments on commit 8749d67

Please sign in to comment.