Skip to content

Commit

Permalink
add test for change property in cli
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Jun 29, 2023
1 parent b3771aa commit caabc69
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions .codespell-whitelist.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
datas
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ repos:
rev: v2.2.1
hooks:
- id: codespell
args: [--ignore-words=.codespell-whitelist.txt]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.9-for-vscode
hooks:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
"num_classes": 12
}
},
"datas": {
"default": {
"name": "cifar10",
"root": "datasets"
}
},
"network": {
"name": "ResNet18",
"num_classes": 12,
Expand Down
8 changes: 7 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs):
self.seed = Variable(1013)
data_factory = partial(DataConfig, name="CIFAR10")
self.datasets = Config(default_factory=data_factory)
self.datas = Config(default_factory=data_factory)
self.datasets.a.num_classes = num_classes
self.datasets.b.num_classes = num_classes
self.network.name = "ResNet18"
Expand All @@ -34,6 +35,10 @@ def post(self):
self.name = self.name.lower()
self.id = f"{self.name}_{self.seed}"

@property
def data(self):
return next(iter(self.datas.values())) if self.datas else self.datas["default"]


class Test:
config = TestConfig()
Expand All @@ -48,9 +53,10 @@ def test_post(self):
assert self.config.datasets.a.name == "cifar10"

def test_parse(self):
self.config.parse(["--name", "Test", "--seed", "1014"])
self.config.parse(["--name", "Test", "--seed", "1014", "--data.root", "datasets"])
assert self.config.name == "test"
assert self.config.id == "test_1014"
assert self.config.data.name == "cifar10"

def test_nested(self):
assert self.config.network.name == "ResNet18"
Expand Down

0 comments on commit caabc69

Please sign in to comment.