Skip to content

Commit

Permalink
post now removes default_factory in Config
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Jul 15, 2024
1 parent 1f35d63 commit e16dcb2
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 34 deletions.
11 changes: 11 additions & 0 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,16 @@ def post(self) -> Self | None:
self:
Examples:
>>> c = Config()
>>> c.dne
Config(<class 'chanfig.config.Config'>, )
>>> c.post()
Config(
('dne'): Config()
)
>>> c.dne2
Traceback (most recent call last):
AttributeError: 'Config' object has no attribute 'dne2'
>>> class PostConfig(Config):
... def post(self):
... if isinstance(self.data, str):
Expand All @@ -190,6 +200,7 @@ def post(self) -> Self | None:

self.interpolate()
self.validate()
self.apply_(lambda c: c.setattr("default_factory", None) if isinstance(c, Config) else None)
return self

def boot(self) -> Self:
Expand Down
76 changes: 42 additions & 34 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,21 @@ def data(self):


class Test:
config = TestConfig()

def test_value(self):
assert self.config.name == "CHANfiG"
config = TestConfig()
assert config.name == "CHANfiG"

def test_post(self):
self.config.boot()
assert self.config.name == "chanfig"
assert self.config.id == "chanfig_1013"
assert self.config.datasets.a.name == "cifar10"
config = TestConfig()
config.boot()
assert config.name == "chanfig"
assert config.id == "chanfig_1013"
assert config.datasets.a.name == "cifar10"

def test_parse(self):
self.config.parse(
config = TestConfig()
config.parse(
[
"--name",
"Test",
Expand All @@ -98,46 +100,52 @@ def test_parse(self):
"path/to/checkpoint.pth",
]
)
assert self.config.name == "test"
assert self.config.id == "test_1014"
assert self.config.checkpoint == "path/to/checkpoint.pth"
assert self.config.data.name == "cifar10"
assert self.config.datas.a.feature_cols == ["a", "b", "c"]
assert self.config.datas.b.label_cols == ["d", "e", "f"]
assert self.config.data.max_length == 1024
assert config.name == "test"
assert config.id == "test_1014"
assert config.checkpoint == "path/to/checkpoint.pth"
assert config.data.name == "cifar10"
assert config.datas.a.feature_cols == ["a", "b", "c"]
assert config.datas.b.label_cols == ["d", "e", "f"]
assert config.data.max_length == 1024

def test_nested(self):
assert self.config.network.name == "ResNet18"
self.config.network.nested.value = 1
config = TestConfig()
assert config.network.name == "ResNet18"
config.network.nested.value = 1

def test_contains(self):
assert "name" in self.config
assert "seed" in self.config
assert "a.b.c" not in self.config
assert "a.b" not in self.config
config = TestConfig()
assert "name" in config
assert "seed" in config
assert "a.b.c" not in config
assert "a.b" not in config

def test_variable(self):
assert self.config.network.num_classes == 10
self.config.network.num_classes += 1
assert self.config.datasets.a.num_classes == 11
config = TestConfig()
assert config.network.num_classes == 10
config.network.num_classes += 1
assert config.datasets.a.num_classes == 11

def test_fstring(self):
assert f"seed{self.config.seed}" == "seed1014"
config = TestConfig()
assert f"seed{config.seed}" == "seed1013"

def test_load(self):
self.config.name = "Test"
self.config.datasets.a.num_classes = 12
config = TestConfig()
config.name = "Test"
config.datasets.a.num_classes = 12
buffer = StringIO()
self.config.dump(buffer, method="json")
assert self.config == Config.load(buffer, method="json")
assert self.config.name == "Test"
assert self.config.network.name == "ResNet18"
assert self.config.network.num_classes == 12
assert self.config.datasets.a.num_classes == 12
config.dump(buffer, method="json")
assert config == Config.load(buffer, method="json")
assert config.name == "Test"
assert config.network.name == "ResNet18"
assert config.network.num_classes == 12
assert config.datasets.a.num_classes == 12

def test_copy(self):
assert self.config.copy() == copy(self.config)
assert self.config.deepcopy() == deepcopy(self.config)
config = TestConfig()
assert config.copy() == copy(config)
assert config.deepcopy() == deepcopy(config)


class Ancestor(Config):
Expand Down

0 comments on commit e16dcb2

Please sign in to comment.