diff --git a/README.md b/README.md index a3bd254..eaa3b8a 100644 --- a/README.md +++ b/README.md @@ -347,7 +347,7 @@ class Test_A(EntryPoint): 我们可以使用字段`default_config_file_paths`指定从固定的几个路径中读取配置文件,配置文件支持`json`和`yaml`两种格式. 我们也可以通过字段`config_file_only_get_need`定义从配置文件中读取配置的行为(默认为`True`), -当置为`True`时我们只会在配置文件中读取schema中定义的字段,否则则会加载全部字段. + 当置为`True`时我们只会在配置文件中读取schema中定义的字段,否则则会加载全部字段. 也可以通过设置`load_all_config_file = True`来按设定顺序读取全部预设的配置文件位置 @@ -369,9 +369,45 @@ class Test_A(EntryPoint): ##### 指定特定命名的配置文件的解析方式 -可以使用`@regist_config_file_parser(config_file_name)`来注册如何解析特定命名的配置文件. +可以使用`@regist_config_file_parser(config_file_name)`来注册如何解析特定命名的配置文件.这一特性可以更好的定制化配置文件的读取 ```python +class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] +root = Test_AC() + +@root.regist_config_file_parser("test_other_config2.json") +def _1(p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} + +``` + +如果想在定义子类时固定好,也可以定义`_config_file_parser_map:Dict[str,Callable[[Path], Dict[str, Any]]]` + +```python +def test_other_config2_parser( p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} +class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] + _config_file_parser_map = { + "test_other_config2.json": test_other_config2_parser + } + +root = Test_AC() ``` diff --git a/schema_entry/entrypoint.py b/schema_entry/entrypoint.py index b3c9544..99dcdcb 100644 --- a/schema_entry/entrypoint.py +++ b/schema_entry/entrypoint.py @@ -18,7 +18,7 @@ import functools from copy import deepcopy from pathlib import Path -from typing import Callable, Sequence, Dict, List, Any, Tuple +from typing import Callable, Sequence, Dict, List, Any, Tuple, Optional from jsonschema import validate from yaml import load as yaml_load @@ -43,7 +43,7 @@ class EntryPoint(EntryPointABC): parse_env = True argparse_check_required = False - argparse_noflag = None + argparse_noflag: Optional[str] = None _config_file_parser_map: Dict[str, Callable[[Path], Dict[str, Any]]] = {} def _check_schema(self) -> None: @@ -267,6 +267,7 @@ def parse_configfile_args(self) -> Dict[str, Any]: if p.is_file(): parfunc = self._config_file_parser_map.get(p.name) if parfunc: + print("&&&&&&") return parfunc(p) if p.suffix == ".json": return self.parse_json_configfile_args(p) @@ -282,12 +283,17 @@ def parse_configfile_args(self) -> Dict[str, Any]: for p_str in self.default_config_file_paths: p = Path(p_str) if p.is_file(): - if p.suffix == ".json": - result.update(self.parse_json_configfile_args(p)) - elif p.suffix == ".yml": - result.update(self.parse_yaml_configfile_args(p)) + parfunc = self._config_file_parser_map.get(p.name) + if parfunc: + print("&&&&&&@@@") + result.update(parfunc(p)) else: - warnings.warn(f"跳过不支持的配置格式的文件{str(p)}") + if p.suffix == ".json": + result.update(self.parse_json_configfile_args(p)) + elif p.suffix == ".yml": + result.update(self.parse_yaml_configfile_args(p)) + else: + warnings.warn(f"跳过不支持的配置格式的文件{str(p)}") return result def validat_config(self) -> bool: diff --git a/test_other_config2.json b/test_other_config2.json new file mode 100644 index 0000000..42cf14f --- /dev/null +++ b/test_other_config2.json @@ -0,0 +1,5 @@ +{ + "A": 1, + "D": 43, + "C": 13 +} \ No newline at end of file diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py index 39c33d0..73eb2b3 100644 --- a/tests/test_entrypoint.py +++ b/tests/test_entrypoint.py @@ -1,6 +1,8 @@ import os +import json import unittest from pathlib import Path +from typing import Dict, Any import jsonschema.exceptions from schema_entry.entrypoint import EntryPoint @@ -52,7 +54,7 @@ class Test_A(EntryPoint): root = Test_A() @root.as_main - def _(a_a: float) -> None: + def _(**kwargs: Any) -> None: pass root([]) @@ -72,12 +74,12 @@ class Test_A(EntryPoint): }, "required": ["a_a"] } - def do_main(self): - assert self.config["a_a"]==33.3 + + def do_main(self) -> None: + assert self.config["a_a"] == 33.3 root = Test_A() root([]) - def test_default_subcmd_usage(self) -> None: class A(EntryPoint): pass @@ -262,6 +264,59 @@ def _(a: int, b: int, c: int, d: int) -> None: root([]) + def test_load_configfile_with_custom_parser(self) -> None: + class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] + root = Test_AC() + + @root.regist_config_file_parser("test_other_config2.json") + def _1(p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} + + @root.as_main + def _2(a: int, b: int, c: int, d: int) -> None: + assert a == 1 + assert b == 2 + assert c == 13 + assert d == 43 + + root([]) + + def test_load_configfile_with_custom_parser_in_class(self) -> None: + def test_other_config2_parser( p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} + class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] + _config_file_parser_map = { + "test_other_config2.json": test_other_config2_parser + } + + + root = Test_AC() + + @root.as_main + def _2(a: int, b: int, c: int, d: int) -> None: + assert a == 1 + assert b == 2 + assert c == 13 + assert d == 43 + + root([]) + def test_load_ENV_config(self) -> None: class Test_A(EntryPoint): env_prefix = "app"