Skip to content

Commit

Permalink
support init from Namespace / file
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed May 28, 2024
1 parent 5bf8144 commit 04ee02b
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 4 deletions.
11 changes: 11 additions & 0 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import annotations

from argparse import Namespace
from collections.abc import Callable, Generator, Iterable, Mapping, MutableMapping, Sequence, Set
from contextlib import contextmanager, suppress
from copy import copy, deepcopy
Expand Down Expand Up @@ -182,6 +183,16 @@ class FlatDict(dict, metaclass=Dict):

indent: int = 2

def __init__(self, *args: Any, **kwargs: Any) -> None:
if len(args) == 1:
arg = args[0]
if isinstance(arg, (PathLike, str, bytes)):
arg = self.load(arg)
elif isinstance(arg, (Namespace,)):
arg = vars(arg)
args = (arg,)
super().__init__(*args, **kwargs)

def __post_init__(self, *args, **kwargs) -> None:
pass

Expand Down
12 changes: 9 additions & 3 deletions chanfig/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@
from contextlib import suppress
from dataclasses import Field
from inspect import isclass
from types import NoneType, UnionType
from typing import TYPE_CHECKING, Any, _UnionGenericAlias, get_args # type: ignore[attr-defined]
from typing import TYPE_CHECKING, Any
from warnings import warn

from typing_extensions import _should_collect_from_parameters, get_args # type: ignore[attr-defined]

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc, assignment]

from .nested_dict import NestedDict
from .utils import Null, get_annotations, parse_bool
from .variable import Variable
Expand Down Expand Up @@ -294,7 +300,7 @@ def add_config_argument(self, key, value: Any | None = None, dtype: type | None
dtype = value.type
elif value is not None:
dtype = type(value)
if isinstance(dtype, (UnionType, _UnionGenericAlias)):
if _should_collect_from_parameters(dtype):
args = get_args(dtype)
if len(args) == 2 and NoneType in args:
dtype = args[0] if args[0] is not NoneType else args[1]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

from argparse import ArgumentParser
from copy import copy, deepcopy
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -97,3 +98,15 @@ def test_validate(self):
ConfigDict(optional_str=None)
with raises(TypeError):
ConfigDict(optional_str=1)

def test_construct_file(self):
d = FlatDict("tests/test.json")
assert d == FlatDict({"a": 1, "b": 2, "c": 3})

def test_construct_namespace(self):
parser = ArgumentParser()
parser.add_argument("--name", type=str)
parser.add_argument("--seed", type=int)
d = FlatDict(parser.parse_args(["--name", "chang", "--seed", "1013"]))
assert d.name == "chang"
assert d.seed == 1013
37 changes: 36 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

from __future__ import annotations

import sys
from typing import List, Optional

import pytest

from chanfig import Config


Expand All @@ -32,10 +35,18 @@ class TestConfig(Config):
f: bool
false: bool
n: bool
no: bool | None
no: Optional[bool]
not_recognized: List[bool]


class TestConfigPEP604(Config):
__test__ = False

true: bool | None
false: bool | None
not_recognized: list[bool]


class Test:

def test_parse_bool(self):
Expand Down Expand Up @@ -86,3 +97,27 @@ def test_parse_bool(self):
)
assert config.t and config.true and config.y and config.yes
assert not config.f and not config.false and not config.n and not config.no

@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 is available in Python 3.10+")
def test_parse_pep604(self):
config = TestConfigPEP604()
config.parse(
[
"--true",
"true",
"--false",
"false",
]
)
assert config.true and not config.false

config = TestConfigPEP604()
config.parse(
[
"--true",
"True",
"--false",
"False",
]
)
assert config.true and not config.false

0 comments on commit 04ee02b

Please sign in to comment.