From 04ee02bc612a5ffe64cf4848455041d297c05afc Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sun, 26 May 2024 22:21:33 +0800 Subject: [PATCH] support init from Namespace / file Signed-off-by: Zhiyuan Chen --- chanfig/flat_dict.py | 11 +++++++++++ chanfig/parser.py | 12 +++++++++--- tests/test_flat_dict.py | 13 +++++++++++++ tests/test_parser.py | 37 ++++++++++++++++++++++++++++++++++++- 4 files changed, 69 insertions(+), 4 deletions(-) diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index 7835e068..efde7aa4 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -19,6 +19,7 @@ from __future__ import annotations +from argparse import Namespace from collections.abc import Callable, Generator, Iterable, Mapping, MutableMapping, Sequence, Set from contextlib import contextmanager, suppress from copy import copy, deepcopy @@ -182,6 +183,16 @@ class FlatDict(dict, metaclass=Dict): indent: int = 2 + def __init__(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 1: + arg = args[0] + if isinstance(arg, (PathLike, str, bytes)): + arg = self.load(arg) + elif isinstance(arg, (Namespace,)): + arg = vars(arg) + args = (arg,) + super().__init__(*args, **kwargs) + def __post_init__(self, *args, **kwargs) -> None: pass diff --git a/chanfig/parser.py b/chanfig/parser.py index 2e5a2088..44bbe17a 100644 --- a/chanfig/parser.py +++ b/chanfig/parser.py @@ -24,10 +24,16 @@ from contextlib import suppress from dataclasses import Field from inspect import isclass -from types import NoneType, UnionType -from typing import TYPE_CHECKING, Any, _UnionGenericAlias, get_args # type: ignore[attr-defined] +from typing import TYPE_CHECKING, Any from warnings import warn +from typing_extensions import _should_collect_from_parameters, get_args # type: ignore[attr-defined] + +try: + from types import NoneType +except ImportError: + NoneType = type(None) # type: ignore[misc, assignment] + from .nested_dict import NestedDict from .utils import Null, get_annotations, parse_bool from .variable import Variable @@ -294,7 +300,7 @@ def add_config_argument(self, key, value: Any | None = None, dtype: type | None dtype = value.type elif value is not None: dtype = type(value) - if isinstance(dtype, (UnionType, _UnionGenericAlias)): + if _should_collect_from_parameters(dtype): args = get_args(dtype) if len(args) == 2 and NoneType in args: dtype = args[0] if args[0] is not NoneType else args[1] diff --git a/tests/test_flat_dict.py b/tests/test_flat_dict.py index e0e9c84a..55eb8b95 100644 --- a/tests/test_flat_dict.py +++ b/tests/test_flat_dict.py @@ -15,6 +15,7 @@ from __future__ import annotations +from argparse import ArgumentParser from copy import copy, deepcopy from typing import Dict, List, Optional, Tuple, Union @@ -97,3 +98,15 @@ def test_validate(self): ConfigDict(optional_str=None) with raises(TypeError): ConfigDict(optional_str=1) + + def test_construct_file(self): + d = FlatDict("tests/test.json") + assert d == FlatDict({"a": 1, "b": 2, "c": 3}) + + def test_construct_namespace(self): + parser = ArgumentParser() + parser.add_argument("--name", type=str) + parser.add_argument("--seed", type=int) + d = FlatDict(parser.parse_args(["--name", "chang", "--seed", "1013"])) + assert d.name == "chang" + assert d.seed == 1013 diff --git a/tests/test_parser.py b/tests/test_parser.py index 9048f597..2a72888f 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -17,8 +17,11 @@ from __future__ import annotations +import sys from typing import List, Optional +import pytest + from chanfig import Config @@ -32,10 +35,18 @@ class TestConfig(Config): f: bool false: bool n: bool - no: bool | None + no: Optional[bool] not_recognized: List[bool] +class TestConfigPEP604(Config): + __test__ = False + + true: bool | None + false: bool | None + not_recognized: list[bool] + + class Test: def test_parse_bool(self): @@ -86,3 +97,27 @@ def test_parse_bool(self): ) assert config.t and config.true and config.y and config.yes assert not config.f and not config.false and not config.n and not config.no + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 is available in Python 3.10+") + def test_parse_pep604(self): + config = TestConfigPEP604() + config.parse( + [ + "--true", + "true", + "--false", + "false", + ] + ) + assert config.true and not config.false + + config = TestConfigPEP604() + config.parse( + [ + "--true", + "True", + "--false", + "False", + ] + ) + assert config.true and not config.false