Skip to content

Commit

Permalink
make all classes "dataclass"es
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Aug 21, 2024
1 parent 6eb0f23 commit dc32dad
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 95 deletions.
58 changes: 5 additions & 53 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,63 +97,15 @@ class Config(NestedDict):
{'f': {'n': 'chang'}, 'i': {'d': 1013}}
"""

parser: None # ConfigParser, Python 3.7 does not support forward reference
frozen: bool
parser = None # ConfigParser, Python 3.7 does not support forward reference
frozen = False

def __init__(self, *args: Any, default_factory: Callable | None = None, **kwargs: Any):
if default_factory is None:
default_factory = Config
self.setattr("frozen", False)
super().__init__(*args, default_factory=default_factory, **kwargs)

def copy_class_attributes(self, recursive: bool = True) -> Self:
r"""
Copy class attributes to instance.
Args:
recursive:
Returns:
self:
Examples:
>>> class Ancestor(Config):
... a = 1
>>> class Parent(Ancestor):
... b = 2
>>> class Child(Parent):
... c = 3
>>> c = Child()
>>> c
Child(<class 'chanfig.config.Config'>, )
>>> c.copy_class_attributes(recursive=False)
Child(<class 'chanfig.config.Config'>,('c'): 3)
>>> c.copy_class_attributes() # doctest: +SKIP
Child(<class 'chanfig.config.Config'>,
('a'): 1,
('b'): 2,
('c'): 3
)
"""

def copy_cls_attributes(cls: type) -> Mapping:
return {
k: v
for k, v in cls.__dict__.items()
if k not in self
and not k.startswith("__")
and (not (isinstance(v, (property, staticmethod, classmethod)) or callable(v)))
}

if recursive:
for cls in self.__class__.__mro__:
if cls.__module__.startswith("chanfig"):
break
self.merge(copy_cls_attributes(cls), overwrite=False)
else:
self.merge(copy_cls_attributes(self.__class__), overwrite=False)
return self

def post(self) -> Self | None:
r"""
Post process of `Config`.
Expand Down Expand Up @@ -288,7 +240,7 @@ def parse(
{'a': 1, 'b': 2, 'c': 3}
"""

if not self.hasattr("parser"):
if self.getattr("parser") is None:
self.setattr("parser", ConfigParser())
self.getattr("parser").parse(args, self, default_config, no_default_config_action)
if boot:
Expand Down Expand Up @@ -330,7 +282,7 @@ def parse_config(
{'a': 1, 'b': 2, 'c': 3}
"""

if not self.hasattr("parser"):
if self.getattr("parser") is None:
self.setattr("parser", ConfigParser())
self.getattr("parser").parse_config(args, self, default_config, no_default_config_action)
if boot:
Expand All @@ -351,7 +303,7 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None:
{'a': 1, 'c': 4, 'b': 2}
"""

if not self.hasattr("parser"):
if self.getattr("parser") is None:
self.setattr("parser", ConfigParser())
return self.getattr("parser").add_argument(*args, **kwargs)

Expand Down
24 changes: 10 additions & 14 deletions chanfig/configclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

from functools import wraps
from typing import Any, Type
from warnings import warn

from .config import Config


def configclass(cls=None, recursive: bool = False):
def configclass(cls=None):
"""
Construct a Config in [`dataclass`][dataclasses.dataclass] style.
Expand All @@ -32,7 +33,6 @@ def configclass(cls=None, recursive: bool = False):
Args:
cls (Type[Any]): The class to be enhanced, provided directly if no parentheses are used.
recursive (bool): If True, recursively copy class attributes. Only applicable if used with parentheses.
Returns:
A modified class with Config functionalities or a decorator with bound parameters.
Expand All @@ -52,24 +52,20 @@ def configclass(cls=None, recursive: bool = False):
)
"""

warn(
"This decorator is deprecated and may be removed in the future release. "
"All chanfig classes will copy variable identified in `__annotations__` by default."
"This decorator is equivalent to inheriting from `Config`.",
PendingDeprecationWarning,
)

def decorator(cls: Type[Any]):
if not issubclass(cls, Config):
config_cls = type(cls.__name__, (Config, cls), dict(cls.__dict__))
cls = config_cls

cls_init = cls.__init__

@wraps(cls_init)
def init(self, *args, **kwargs):
cls_init(self)
self.copy_class_attributes(recursive=recursive)
self.merge(*args, **kwargs)

setattr(cls, "__init__", init) # noqa: B010

return cls

if cls is None:
return decorator
else:
return decorator(cls)
return decorator(cls)
14 changes: 8 additions & 6 deletions chanfig/default_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class DefaultDict(FlatDict):
TypeError: `default_factory=[]` must be Callable, but got <class 'list'>.
"""

default_factory: Optional[Callable] = None
default_factory = None

def __init__( # pylint: disable=W1113
self, default_factory: Callable | None = None, *args: Any, **kwargs: Any
Expand All @@ -82,12 +82,13 @@ def __missing__(self, name: Any, default=Null) -> Any: # pylint: disable=R1710
return default

def __repr__(self) -> str:
if self.default_factory is None:
default_factory = self.getattr("default_factory", None)
if default_factory is None:
return super().__repr__()
super_repr = super().__repr__()[len(self.__class__.__name__) :] # noqa: E203
if len(super_repr) == 2:
return f"{self.__class__.__name__}({self.default_factory}, )"
return f"{self.__class__.__name__}({self.default_factory}," + super_repr[1:]
return f"{self.__class__.__name__}({default_factory}, )"
return f"{self.__class__.__name__}({default_factory}," + super_repr[1:]

def add(self, name: Any):
r"""
Expand Down Expand Up @@ -116,7 +117,8 @@ def add(self, name: Any):
Traceback (most recent call last):
ValueError: Cannot add to a DefaultDict with no default_factory
"""
if self.default_factory is None:
default_factory = self.getattr("default_factory", None)
if default_factory is None:
raise ValueError("Cannot add to a DefaultDict with no default_factory")
self.set(name, self.default_factory()) # pylint: disable=E1102
self.set(name, default_factory()) # pylint: disable=E1102
return self.get(name)
29 changes: 28 additions & 1 deletion chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class FlatDict(dict, metaclass=Dict):

# pylint: disable=R0904

indent: int = 2
indent = 2

def __init__(self, *args: Any, **kwargs: Any) -> None:
if len(args) == 1:
Expand All @@ -192,6 +192,33 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
arg = vars(arg)
args = (arg,)
super().__init__(*args, **kwargs)
self.move_class_attributes()

def move_class_attributes(self, recursive: bool = True) -> Self:
r"""
Move class attributes to instance.
Args:
recursive:
Returns:
self:
"""

def move_cls_attributes(cls: type) -> Mapping:
attributes = {}
for k in get_annotations(cls).keys():
if k in cls.__dict__:
attributes[k] = cls.__dict__[k]
delattr(cls, k)
return attributes

if recursive:
for cls in self.__class__.__mro__:
self.merge(move_cls_attributes(cls), overwrite=False)
else:
self.merge(move_cls_attributes(self.__class__), overwrite=False)
return self

def __post_init__(self, *args, **kwargs) -> None:
pass
Expand Down
6 changes: 3 additions & 3 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ class NestedDict(DefaultDict): # pylint: disable=E1136
{'f': {'n': 'chang'}, 'i': {'d': 1013}}
"""

convert_mapping: bool = False
delimiter: str = "."
fallback: bool = False
convert_mapping = False
delimiter = "."
fallback = False

def __init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions chanfig/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ class Registry(NestedDict):
(1, 0)
"""

override: bool = False
key: str = "name"
default: Any = Null
override = False
key = "name"
default = Null

def __init__(
self, override: bool | None = None, key: str | None = None, fallback: bool | None = None, default: Any = None
Expand Down
4 changes: 2 additions & 2 deletions demo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import os

from chanfig import Config, Variable, configclass
from chanfig import Config, Variable


@configclass
class DataloaderConfig:
batch_size: int = 64
num_workers: int = 4
pin_memory: bool = True
attribute = "None" # this will not be copied to the config


class TestConfig(Config):
Expand Down
7 changes: 0 additions & 7 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,3 @@ class TestConfigDict:

def test_affinty(self):
assert id(self.dict.a) == id(self.dict.b.a) == id(self.dict.c.a) == id(self.dict.d.a)

def test_copy_class_attributes(self):
config = self.dict.copy_class_attributes(recursive=False)
assert config.child == 3
assert "ancestor" not in config
config = self.dict.copy_class_attributes()
assert config.ancestor == 1
7 changes: 1 addition & 6 deletions tests/test_configclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class AncestorConfig:
seed: int = Variable(1013, help="random seed")


@configclass(recursive=True)
@configclass
class ChildConfig(AncestorConfig):
__test__ = False
name: str = "CHANfiG"
Expand All @@ -41,9 +41,4 @@ class Test:
def test_configclass(self):
config = TestConfig()
assert config.name == "CHANfiG"
assert "seed" not in config

def test_configclass_recursive(self):
config = ChildConfig()
assert config.name == "CHANfiG"
assert "seed" in config

0 comments on commit dc32dad

Please sign in to comment.