diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt new file mode 100644 index 00000000..1449575d --- /dev/null +++ b/.codespell-whitelist.txt @@ -0,0 +1 @@ +datas \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26c13585..6a28b1f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,6 +40,7 @@ repos: rev: v2.2.1 hooks: - id: codespell + args: [--ignore-words=.codespell-whitelist.txt] - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.0.0-alpha.9-for-vscode hooks: diff --git a/tests/test_config.json b/tests/test_config.json index 2745f1a8..1953c804 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -11,6 +11,12 @@ "num_classes": 12 } }, + "datas": { + "default": { + "name": "cifar10", + "root": "datasets" + } + }, "network": { "name": "ResNet18", "num_classes": 12, diff --git a/tests/test_config.py b/tests/test_config.py index 19c32075..1735147f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs): self.seed = Variable(1013) data_factory = partial(DataConfig, name="CIFAR10") self.datasets = Config(default_factory=data_factory) + self.datas = Config(default_factory=data_factory) self.datasets.a.num_classes = num_classes self.datasets.b.num_classes = num_classes self.network.name = "ResNet18" @@ -34,6 +35,10 @@ def post(self): self.name = self.name.lower() self.id = f"{self.name}_{self.seed}" + @property + def data(self): + return next(iter(self.datas.values())) if self.datas else self.datas["default"] + class Test: config = TestConfig() @@ -48,9 +53,10 @@ def test_post(self): assert self.config.datasets.a.name == "cifar10" def test_parse(self): - self.config.parse(["--name", "Test", "--seed", "1014"]) + self.config.parse(["--name", "Test", "--seed", "1014", "--data.root", "datasets"]) assert self.config.name == "test" assert self.config.id == "test_1014" + assert self.config.data.name == "cifar10" def test_nested(self): assert self.config.network.name == "ResNet18"