diff --git a/Changelog.md b/Changelog.md index 79841cc..95def24 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,4 +1,10 @@ -# 0.08 +# 0.0.9 + +## 新特性 + ++ 可以使用`@regist_config_file_parser(config_file_name)`来注册如何解析特定命名的配置文件 + +# 0.0.8 ## 新特性 diff --git a/README.md b/README.md index f4a8c60..eaa3b8a 100644 --- a/README.md +++ b/README.md @@ -347,7 +347,7 @@ class Test_A(EntryPoint): 我们可以使用字段`default_config_file_paths`指定从固定的几个路径中读取配置文件,配置文件支持`json`和`yaml`两种格式. 我们也可以通过字段`config_file_only_get_need`定义从配置文件中读取配置的行为(默认为`True`), -当置为`True`时我们只会在配置文件中读取schema中定义的字段,否则则会加载全部字段. + 当置为`True`时我们只会在配置文件中读取schema中定义的字段,否则则会加载全部字段. 也可以通过设置`load_all_config_file = True`来按设定顺序读取全部预设的配置文件位置 @@ -362,8 +362,53 @@ class Test_A(EntryPoint): default_config_file_paths = [ "/test_config.json", str(Path.home().joinpath(".test_config.json")), - "./test_config.json" + "./test_config.json", + "./test_config_other.json" + ] +``` + +##### 指定特定命名的配置文件的解析方式 + +可以使用`@regist_config_file_parser(config_file_name)`来注册如何解析特定命名的配置文件.这一特性可以更好的定制化配置文件的读取 + +```python +class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] +root = Test_AC() + +@root.regist_config_file_parser("test_other_config2.json") +def _1(p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} + +``` + +如果想在定义子类时固定好,也可以定义`_config_file_parser_map:Dict[str,Callable[[Path], Dict[str, Any]]]` + +```python +def test_other_config2_parser( p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} +class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" ] + _config_file_parser_map = { + "test_other_config2.json": test_other_config2_parser + } + +root = Test_AC() + ``` #### 从环境变量中读取配置参数 @@ -435,6 +480,29 @@ def main(a,b): ``` +另一种指定入口函数的方法是重写子类的`do_main(self)->None`方法 + +```python +class Test_A(EntryPoint): + argparse_noflag = "a" + argparse_check_required=True + schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "a": { + "type": "number" + }, + "b": { + "type": "number" + } + }, + "required": ["a","b"] + } + def do_main(self)->None: + print(self.config) +``` + #### 直接从节点对象中获取配置 节点对象的`config`属性会在每次调用时copy一份当前的配置值,config是不可写的. diff --git a/schema_entry/entrypoint.py b/schema_entry/entrypoint.py index 4ddcfb1..99dcdcb 100644 --- a/schema_entry/entrypoint.py +++ b/schema_entry/entrypoint.py @@ -18,7 +18,7 @@ import functools from copy import deepcopy from pathlib import Path -from typing import Callable, Sequence, Dict, List, Any, Tuple +from typing import Callable, Sequence, Dict, List, Any, Tuple, Optional from jsonschema import validate from yaml import load as yaml_load @@ -43,7 +43,8 @@ class EntryPoint(EntryPointABC): parse_env = True argparse_check_required = False - argparse_noflag = None + argparse_noflag: Optional[str] = None + _config_file_parser_map: Dict[str, Callable[[Path], Dict[str, Any]]] = {} def _check_schema(self) -> None: if self.schema is not None: @@ -248,6 +249,15 @@ def parse_yaml_configfile_args(self, p: Path) -> Dict[str, Any]: return res return result + def regist_config_file_parser(self, file_name: str) -> Callable[[Callable[[Path], Dict[str, Any]]], Callable[[Path], Dict[str, Any]]]: + def decorate(func: Callable[[Path], Dict[str, Any]]) -> Callable[[Path], Dict[str, Any]]: + @functools.wraps(func) + def wrap(p: Path) -> Dict[str, Any]: + return func(p) + self._config_file_parser_map[file_name] = func + return wrap + return decorate + def parse_configfile_args(self) -> Dict[str, Any]: if not self.default_config_file_paths: return {} @@ -255,6 +265,10 @@ def parse_configfile_args(self) -> Dict[str, Any]: for p_str in self.default_config_file_paths: p = Path(p_str) if p.is_file(): + parfunc = self._config_file_parser_map.get(p.name) + if parfunc: + print("&&&&&&") + return parfunc(p) if p.suffix == ".json": return self.parse_json_configfile_args(p) elif p.suffix == ".yml": @@ -269,12 +283,17 @@ def parse_configfile_args(self) -> Dict[str, Any]: for p_str in self.default_config_file_paths: p = Path(p_str) if p.is_file(): - if p.suffix == ".json": - result.update(self.parse_json_configfile_args(p)) - elif p.suffix == ".yml": - result.update(self.parse_yaml_configfile_args(p)) + parfunc = self._config_file_parser_map.get(p.name) + if parfunc: + print("&&&&&&@@@") + result.update(parfunc(p)) else: - warnings.warn(f"跳过不支持的配置格式的文件{str(p)}") + if p.suffix == ".json": + result.update(self.parse_json_configfile_args(p)) + elif p.suffix == ".yml": + result.update(self.parse_yaml_configfile_args(p)) + else: + warnings.warn(f"跳过不支持的配置格式的文件{str(p)}") return result def validat_config(self) -> bool: diff --git a/schema_entry/entrypoint_base.py b/schema_entry/entrypoint_base.py index aeb3c64..4806b54 100644 --- a/schema_entry/entrypoint_base.py +++ b/schema_entry/entrypoint_base.py @@ -1,6 +1,7 @@ """入口类的抽象基类.""" import abc import argparse +from pathlib import Path from typing import Callable, Sequence, Dict, Any, Optional, List, Union, Tuple @@ -40,6 +41,7 @@ class EntryPointABC(abc.ABC): _subcmds: Dict[str, "EntryPointABC"] _main: Optional[Callable[..., None]] + _config_file_parser_map: Dict[str, Callable[[Path], Dict[str, Any]]] _config: Dict[str, Any] @abc.abstractproperty @@ -79,6 +81,19 @@ def regist_sub(self, subcmdclz: type) -> "EntryPointABC": [EntryPointABC]: 注册类的实例 ''' + + @abc.abstractmethod + def regist_config_file_parser(self, file_name: str) -> Callable[[Callable[[Path], Dict[str, Any]]], Callable[[Path], Dict[str, Any]]]: + '''注册特定配置文件名的解析方式. + + Args: + file_name (str): 指定文件名 + + Returns: + Callable[[Callable[[Path], None]], Callable[[Path], None]]: 注册的解析函数 + + ''' + @abc.abstractmethod def as_main(self, func: Callable[..., None]) -> Callable[..., None]: """注册函数在解析参数成功后执行. diff --git a/setup.cfg b/setup.cfg index 5b8c1e2..c4ed7c6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = schema_entry -version = 0.0.8 +version = 0.0.9 url = https://github.com/Python-Tools/schema_entry author = hsz author_email = hsz1273327@gmail.com diff --git a/test_other_config2.json b/test_other_config2.json new file mode 100644 index 0000000..42cf14f --- /dev/null +++ b/test_other_config2.json @@ -0,0 +1,5 @@ +{ + "A": 1, + "D": 43, + "C": 13 +} \ No newline at end of file diff --git a/tests/test_entrypoint.py b/tests/test_entrypoint.py index 379ad5b..73eb2b3 100644 --- a/tests/test_entrypoint.py +++ b/tests/test_entrypoint.py @@ -1,6 +1,8 @@ import os +import json import unittest from pathlib import Path +from typing import Dict, Any import jsonschema.exceptions from schema_entry.entrypoint import EntryPoint @@ -52,13 +54,32 @@ class Test_A(EntryPoint): root = Test_A() @root.as_main - def _(a_a: float) -> None: + def _(**kwargs: Any) -> None: pass root([]) assert root.usage == "test_a [options]" + def test_override_do_main(self) -> None: + class Test_A(EntryPoint): + schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "a_a": { + "type": "number", + "default": 33.3 + } + }, + "required": ["a_a"] + } + + def do_main(self) -> None: + assert self.config["a_a"] == 33.3 + root = Test_A() + root([]) + def test_default_subcmd_usage(self) -> None: class A(EntryPoint): pass @@ -243,6 +264,59 @@ def _(a: int, b: int, c: int, d: int) -> None: root([]) + def test_load_configfile_with_custom_parser(self) -> None: + class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] + root = Test_AC() + + @root.regist_config_file_parser("test_other_config2.json") + def _1(p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} + + @root.as_main + def _2(a: int, b: int, c: int, d: int) -> None: + assert a == 1 + assert b == 2 + assert c == 13 + assert d == 43 + + root([]) + + def test_load_configfile_with_custom_parser_in_class(self) -> None: + def test_other_config2_parser( p: Path) -> Dict[str, Any]: + with open(p) as f: + temp = json.load(f) + return {k.lower(): v for k, v in temp.items()} + class Test_AC(EntryPoint): + load_all_config_file = True + default_config_file_paths = [ + "./test_config.json", + "./test_config1.json", + "./test_other_config2.json" + ] + _config_file_parser_map = { + "test_other_config2.json": test_other_config2_parser + } + + + root = Test_AC() + + @root.as_main + def _2(a: int, b: int, c: int, d: int) -> None: + assert a == 1 + assert b == 2 + assert c == 13 + assert d == 43 + + root([]) + def test_load_ENV_config(self) -> None: class Test_A(EntryPoint): env_prefix = "app"