Skip to content

Commit

Permalink
fix value in config file overrided by default
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Feb 8, 2023
1 parent e36b9d2 commit 2372f9d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
11 changes: 8 additions & 3 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def hasattr(self, name: str) -> bool:
except AttributeError:
return False

def __missing__(self, name: str, default: Any = Null) -> Any:
def __missing__(self, name: str, default: Any = Null) -> Any: # pylint: disable=R1710
if name == "_ipython_canary_method_should_not_exist_":
return
if default is Null:
Expand Down Expand Up @@ -867,7 +867,9 @@ def from_jsons(cls, string: str, *args, **kwargs) -> FlatDict:
```
"""

return cls(**json_loads(string, *args, **kwargs))
config = cls()
config.update(json_loads(string, *args, **kwargs))
return config

def yaml(self, file: File, *args, **kwargs) -> None:
r"""
Expand Down Expand Up @@ -950,7 +952,10 @@ def from_yamls(cls, string: str, *args, **kwargs) -> FlatDict:

if "Loader" not in kwargs:
kwargs["Loader"] = YamlLoader
return cls(**yaml_load(string, *args, **kwargs))

config = cls()
config.update(yaml_load(string, *args, **kwargs))
return config

@staticmethod
@contextmanager
Expand Down
2 changes: 1 addition & 1 deletion chanfig/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __contains__(self, name):
def __iter__(self):
return self

def __next__(self):
def __next__(self): # pylint: disable=R0201
raise StopIteration


Expand Down
12 changes: 11 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class TestConfig(Config):
__test__ = False

def __init__(self, *args, **kwargs):
super().__init__()
super().__init__(*args, **kwargs)
num_classes = Variable(10)
self.name = "CHANfiG"
self.seed = Variable(1013)
Expand All @@ -31,3 +31,13 @@ def test_variable(self):

def test_fstring(self):
assert f"seed{self.config.seed}" == "seed1013"

def test_load(self):
self.config.name = "Test"
self.config.dataset.num_classes = 12
self.config.dump("tests/test_config.json")
self.config = self.config.load("tests/test_config.json")
assert self.config.name == "Test"
assert self.config.network.name == "ResNet18"
assert self.config.network.num_classes == 12
assert self.config.dataset.num_classes == 12

0 comments on commit 2372f9d

Please sign in to comment.