From dc32dad4e2fd193b4401026a3018cb15f4a0d024 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Wed, 21 Aug 2024 16:34:47 +0800 Subject: [PATCH] make all classes "dataclass"es Signed-off-by: Zhiyuan Chen --- chanfig/config.py | 58 ++++----------------------------------- chanfig/configclasses.py | 24 +++++++--------- chanfig/default_dict.py | 14 ++++++---- chanfig/flat_dict.py | 29 +++++++++++++++++++- chanfig/nested_dict.py | 6 ++-- chanfig/registry.py | 6 ++-- demo/config.py | 4 +-- tests/test_config.py | 7 ----- tests/test_configclass.py | 7 +---- 9 files changed, 60 insertions(+), 95 deletions(-) diff --git a/chanfig/config.py b/chanfig/config.py index 9ecb0100..60768d86 100755 --- a/chanfig/config.py +++ b/chanfig/config.py @@ -97,8 +97,8 @@ class Config(NestedDict): {'f': {'n': 'chang'}, 'i': {'d': 1013}} """ - parser: None # ConfigParser, Python 3.7 does not support forward reference - frozen: bool + parser = None # ConfigParser, Python 3.7 does not support forward reference + frozen = False def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs: Any): if default_factory is None: @@ -106,54 +106,6 @@ def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs self.setattr("frozen", False) super().__init__(*args, default_factory=default_factory, **kwargs) - def copy_class_attributes(self, recursive: bool = True) -> Self: - r""" - Copy class attributes to instance. - - Args: - recursive: - - Returns: - self: - - Examples: - >>> class Ancestor(Config): - ... a = 1 - >>> class Parent(Ancestor): - ... b = 2 - >>> class Child(Parent): - ... c = 3 - >>> c = Child() - >>> c - Child(, ) - >>> c.copy_class_attributes(recursive=False) - Child(,('c'): 3) - >>> c.copy_class_attributes() # doctest: +SKIP - Child(, - ('a'): 1, - ('b'): 2, - ('c'): 3 - ) - """ - - def copy_cls_attributes(cls: type) -> Mapping: - return { - k: v - for k, v in cls.__dict__.items() - if k not in self - and not k.startswith("__") - and (not (isinstance(v, (property, staticmethod, classmethod)) or callable(v))) - } - - if recursive: - for cls in self.__class__.__mro__: - if cls.__module__.startswith("chanfig"): - break - self.merge(copy_cls_attributes(cls), overwrite=False) - else: - self.merge(copy_cls_attributes(self.__class__), overwrite=False) - return self - def post(self) -> Self | None: r""" Post process of `Config`. @@ -288,7 +240,7 @@ def parse( {'a': 1, 'b': 2, 'c': 3} """ - if not self.hasattr("parser"): + if self.getattr("parser") is None: self.setattr("parser", ConfigParser()) self.getattr("parser").parse(args, self, default_config, no_default_config_action) if boot: @@ -330,7 +282,7 @@ def parse_config( {'a': 1, 'b': 2, 'c': 3} """ - if not self.hasattr("parser"): + if self.getattr("parser") is None: self.setattr("parser", ConfigParser()) self.getattr("parser").parse_config(args, self, default_config, no_default_config_action) if boot: @@ -351,7 +303,7 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None: {'a': 1, 'c': 4, 'b': 2} """ - if not self.hasattr("parser"): + if self.getattr("parser") is None: self.setattr("parser", ConfigParser()) return self.getattr("parser").add_argument(*args, **kwargs) diff --git a/chanfig/configclasses.py b/chanfig/configclasses.py index f2395f30..0b9f0df1 100644 --- a/chanfig/configclasses.py +++ b/chanfig/configclasses.py @@ -17,11 +17,12 @@ from functools import wraps from typing import Any, Type +from warnings import warn from .config import Config -def configclass(cls=None, recursive: bool = False): +def configclass(cls=None): """ Construct a Config in [`dataclass`][dataclasses.dataclass] style. @@ -32,7 +33,6 @@ def configclass(cls=None, recursive: bool = False): Args: cls (Type[Any]): The class to be enhanced, provided directly if no parentheses are used. - recursive (bool): If True, recursively copy class attributes. Only applicable if used with parentheses. Returns: A modified class with Config functionalities or a decorator with bound parameters. @@ -52,24 +52,20 @@ def configclass(cls=None, recursive: bool = False): ) """ + warn( + "This decorator is deprecated and may be removed in the future release. " + "All chanfig classes will copy variable identified in `__annotations__` by default." + "This decorator is equivalent to inheriting from `Config`.", + PendingDeprecationWarning, + ) + def decorator(cls: Type[Any]): if not issubclass(cls, Config): config_cls = type(cls.__name__, (Config, cls), dict(cls.__dict__)) cls = config_cls - cls_init = cls.__init__ - - @wraps(cls_init) - def init(self, *args, **kwargs): - cls_init(self) - self.copy_class_attributes(recursive=recursive) - self.merge(*args, **kwargs) - - setattr(cls, "__init__", init) # noqa: B010 - return cls if cls is None: return decorator - else: - return decorator(cls) + return decorator(cls) diff --git a/chanfig/default_dict.py b/chanfig/default_dict.py index 02ac7efb..89e10168 100644 --- a/chanfig/default_dict.py +++ b/chanfig/default_dict.py @@ -57,7 +57,7 @@ class DefaultDict(FlatDict): TypeError: `default_factory=[]` must be Callable, but got . """ - default_factory: Optional[Callable] = None + default_factory = None def __init__( # pylint: disable=W1113 self, default_factory: Callable | None = None, *args: Any, **kwargs: Any @@ -82,12 +82,13 @@ def __missing__(self, name: Any, default=Null) -> Any: # pylint: disable=R1710 return default def __repr__(self) -> str: - if self.default_factory is None: + default_factory = self.getattr("default_factory", None) + if default_factory is None: return super().__repr__() super_repr = super().__repr__()[len(self.__class__.__name__) :] # noqa: E203 if len(super_repr) == 2: - return f"{self.__class__.__name__}({self.default_factory}, )" - return f"{self.__class__.__name__}({self.default_factory}," + super_repr[1:] + return f"{self.__class__.__name__}({default_factory}, )" + return f"{self.__class__.__name__}({default_factory}," + super_repr[1:] def add(self, name: Any): r""" @@ -116,7 +117,8 @@ def add(self, name: Any): Traceback (most recent call last): ValueError: Cannot add to a DefaultDict with no default_factory """ - if self.default_factory is None: + default_factory = self.getattr("default_factory", None) + if default_factory is None: raise ValueError("Cannot add to a DefaultDict with no default_factory") - self.set(name, self.default_factory()) # pylint: disable=E1102 + self.set(name, default_factory()) # pylint: disable=E1102 return self.get(name) diff --git a/chanfig/flat_dict.py b/chanfig/flat_dict.py index 64ca23aa..7f737bc8 100644 --- a/chanfig/flat_dict.py +++ b/chanfig/flat_dict.py @@ -181,7 +181,7 @@ class FlatDict(dict, metaclass=Dict): # pylint: disable=R0904 - indent: int = 2 + indent = 2 def __init__(self, *args: Any, **kwargs: Any) -> None: if len(args) == 1: @@ -192,6 +192,33 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: arg = vars(arg) args = (arg,) super().__init__(*args, **kwargs) + self.move_class_attributes() + + def move_class_attributes(self, recursive: bool = True) -> Self: + r""" + Move class attributes to instance. + + Args: + recursive: + + Returns: + self: + """ + + def move_cls_attributes(cls: type) -> Mapping: + attributes = {} + for k in get_annotations(cls).keys(): + if k in cls.__dict__: + attributes[k] = cls.__dict__[k] + delattr(cls, k) + return attributes + + if recursive: + for cls in self.__class__.__mro__: + self.merge(move_cls_attributes(cls), overwrite=False) + else: + self.merge(move_cls_attributes(self.__class__), overwrite=False) + return self def __post_init__(self, *args, **kwargs) -> None: pass diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index ffd50aa4..a9e5688a 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -162,9 +162,9 @@ class NestedDict(DefaultDict): # pylint: disable=E1136 {'f': {'n': 'chang'}, 'i': {'d': 1013}} """ - convert_mapping: bool = False - delimiter: str = "." - fallback: bool = False + convert_mapping = False + delimiter = "." + fallback = False def __init__( self, diff --git a/chanfig/registry.py b/chanfig/registry.py index 53e21936..1d76341b 100644 --- a/chanfig/registry.py +++ b/chanfig/registry.py @@ -94,9 +94,9 @@ class Registry(NestedDict): (1, 0) """ - override: bool = False - key: str = "name" - default: Any = Null + override = False + key = "name" + default = Null def __init__( self, override: bool | None = None, key: str | None = None, fallback: bool | None = None, default: Any = None diff --git a/demo/config.py b/demo/config.py index d84be2c3..545b2634 100644 --- a/demo/config.py +++ b/demo/config.py @@ -17,14 +17,14 @@ import os -from chanfig import Config, Variable, configclass +from chanfig import Config, Variable -@configclass class DataloaderConfig: batch_size: int = 64 num_workers: int = 4 pin_memory: bool = True + attribute = "None" # this will not be copied to the config class TestConfig(Config): diff --git a/tests/test_config.py b/tests/test_config.py index 57c53ab0..f103b28b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -172,10 +172,3 @@ class TestConfigDict: def test_affinty(self): assert id(self.dict.a) == id(self.dict.b.a) == id(self.dict.c.a) == id(self.dict.d.a) - - def test_copy_class_attributes(self): - config = self.dict.copy_class_attributes(recursive=False) - assert config.child == 3 - assert "ancestor" not in config - config = self.dict.copy_class_attributes() - assert config.ancestor == 1 diff --git a/tests/test_configclass.py b/tests/test_configclass.py index 08329416..edc92d42 100644 --- a/tests/test_configclass.py +++ b/tests/test_configclass.py @@ -25,7 +25,7 @@ class AncestorConfig: seed: int = Variable(1013, help="random seed") -@configclass(recursive=True) +@configclass class ChildConfig(AncestorConfig): __test__ = False name: str = "CHANfiG" @@ -41,9 +41,4 @@ class Test: def test_configclass(self): config = TestConfig() assert config.name == "CHANfiG" - assert "seed" not in config - - def test_configclass_recursive(self): - config = ChildConfig() - assert config.name == "CHANfiG" assert "seed" in config