Skip to content

Commit

Permalink
allow no super().__init__() call in childern class
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Jan 16, 2023
1 parent 1dee88e commit 36f77ca
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 33 deletions.
14 changes: 10 additions & 4 deletions .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
- uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }}
- name: Install dependencies
run: pip install -r requirements.txt
run: pip install -r requirements.txt && pip install -e .
- name: Install dependencies for testing
run: pip install torch
- name: Install pytest
run: pip install pytest
run: pip install pytest torch
- name: doctest
run: pytest --doctest-modules chanfig
- name: unittest
run: pytest tests/
release:
if: startsWith(github.event.ref, 'refs/tags/v')
needs: [lint, test]
Expand All @@ -46,6 +50,7 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: 3.x
cache: "pip"
- uses: Gr1N/setup-poetry@v7
- name: install building dependencies
run: poetry self add "poetry-dynamic-versioning[plugin]"
Expand Down Expand Up @@ -79,6 +84,7 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: 3.x
cache: "pip"
- uses: Gr1N/setup-poetry@v7
- name: install building dependencies
run: poetry self add "poetry-dynamic-versioning[plugin]"
Expand Down
18 changes: 9 additions & 9 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def frozen_check(func: Callable):

@wraps(func)
def decorator(self, *args, **kwargs):
if self.getattr("frozen"):
if self.getattr("frozen", False):
raise ValueError("Attempting to alter a frozen config. Run config.defrost() to defrost first")
return func(self, *args, **kwargs)

Expand All @@ -110,18 +110,16 @@ class Config(NestedDict):
It is recommended to call `config.freeze()` or `config.to(NestedDict)` to avoid this behavior.
"""

default_factory: Optional[Callable]
parser: ConfigParser
frozen: bool = False
convert_mapping: bool = True
delimiter: str = "."
indent: int = 2
parser: ConfigParser
default_factory: Optional[Callable]

def __init__(self, *args, **kwargs):
self.setattr("frozen", False)
self.setattr("parser", ConfigParser())
super().__init__(*args, default_factory=Config, **kwargs)
self.setattr("convert_mapping", True)

def get(self, name: str, default: Optional[Any] = None) -> Any:
r"""
Expand Down Expand Up @@ -163,7 +161,9 @@ def get(self, name: str, default: Optional[Any] = None) -> Any:
```
"""

if name in self or not self.getattr("frozen"):
if "default_factory" not in self: # did not call super().__init__() in sub-class
self.setattr("default_factory", Config)
if name in self or not self.getattr("frozen", False):
return super().get(name, default)
raise KeyError(f"{self.__class__.__name__} does not contain {name}")

Expand Down Expand Up @@ -373,7 +373,7 @@ def unlocked(self):
```
"""

was_frozen = self.getattr("frozen")
was_frozen = self.getattr("frozen", False)
try:
self.defrost()
yield self
Expand Down Expand Up @@ -411,7 +411,7 @@ def parse(
```
"""

return self.getattr("parser").parse(args, self, default_config)
return self.getattr("parser", ConfigParser()).parse(args, self, default_config)

parse_config = parse

Expand All @@ -429,4 +429,4 @@ def add_argument(self, *args, **kwargs) -> None:
```
"""

self.getattr("parser").add_argument(*args, **kwargs)
self.getattr("parser", ConfigParser()).add_argument(*args, **kwargs)
10 changes: 5 additions & 5 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class FlatDict(OrderedDict):

# pylint: disable=R0904

default_factory: Optional[Callable]
indent: int = 2
default_factory: Optional[Callable]

def __init__(self, *args, default_factory: Optional[Callable] = None, **kwargs):
super().__init__()
Expand All @@ -92,7 +92,6 @@ def __init__(self, *args, default_factory: Optional[Callable] = None, **kwargs):
raise TypeError(
f"default_factory={default_factory} should be of type Callable, but got {type(default_factory)}"
)
self.setattr("indent", 2)
self._init(*args, **kwargs)

def _init(self, *args, **kwargs) -> None:
Expand All @@ -106,6 +105,7 @@ def _init(self, *args, **kwargs) -> None:
**kwargs: {key1: value1, key2: value2}.
"""

self.setattr("indent", 2)
for key, value in args:
self.set(key, value)
for key, value in kwargs.items():
Expand Down Expand Up @@ -716,7 +716,7 @@ def jsons(self, *args, **kwargs) -> str:
if "cls" not in kwargs:
kwargs["cls"] = JsonEncoder
if "indent" not in kwargs:
kwargs["indent"] = self.getattr("indent")
kwargs["indent"] = self.getattr("indent", 2)
return json_dumps(self.dict(), *args, **kwargs)

@classmethod
Expand Down Expand Up @@ -788,7 +788,7 @@ def yamls(self, *args, **kwargs) -> str:
if "Dumper" not in kwargs:
kwargs["Dumper"] = YamlDumper
if "indent" not in kwargs:
kwargs["indent"] = self.getattr("indent")
kwargs["indent"] = self.getattr("indent", 2)
return yaml_dump(self.dict(), *args, **kwargs) # type: ignore

@classmethod
Expand Down Expand Up @@ -923,7 +923,7 @@ def _add_indent(self, text):
if len(lines) == 1:
return text
first = lines.pop(0)
lines = [(self.getattr("indent") * " ") + line for line in lines]
lines = [(self.getattr("indent", 2) * " ") + line for line in lines]
lines = "\n".join(lines)
lines = first + "\n" + lines
return lines
Expand Down
36 changes: 21 additions & 15 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@ class NestedDict(FlatDict):
```
"""

default_mapping: Callable
convert_mapping: bool = False
default_factory: Optional[Callable]
default_mapping: Optional[Callable]
delimiter: str = "."
indent: int = 2
default_factory: Optional[Callable]

def __init__(self, *args, default_factory: Optional[Callable] = None, **kwargs):
self.setattr("delimiter", ".")
self.setattr("convert_mapping", False)
self.setattr("default_mapping", NestedDict)
super().__init__(*args, default_factory=default_factory, **kwargs)

Expand Down Expand Up @@ -89,8 +87,9 @@ def get(self, name: str, default: Optional[Any] = None) -> Any:
```
"""

while self.getattr("delimiter") in name:
name, rest = name.split(self.getattr("delimiter"), 1)
delimiter = self.getattr("delimiter", ".")
while delimiter in name:
name, rest = name.split(delimiter, 1)
self, name = self[name], rest # pylint: disable=W0642
return super().get(name, default)

Expand Down Expand Up @@ -138,11 +137,12 @@ def set( # pylint: disable=W0221
```
"""

default_mapping = self.getattr("default_mapping")
delimiter = self.getattr("delimiter", ".")
default_mapping = self.getattr("default_mapping", NestedDict)
if convert_mapping is None:
convert_mapping = self.convert_mapping
while self.getattr("delimiter") in name:
name, rest = name.split(self.getattr("delimiter"), 1)
while delimiter in name:
name, rest = name.split(delimiter, 1)
if name not in self:
if convert_mapping:
super().__setitem__(name, default_mapping())
Expand Down Expand Up @@ -178,8 +178,9 @@ def __contains__(self, name: str) -> bool: # type: ignore
```
"""

while self.getattr("delimiter") in name:
name, rest = name.split(self.getattr("delimiter"), 1)
delimiter = self.getattr("delimiter", ".")
while delimiter in name:
name, rest = name.split(delimiter, 1)
self, name = self[name], rest # pylint: disable=W0642
return super().__contains__(name)

Expand All @@ -206,8 +207,9 @@ def pop(self, name: str, default: Optional[Any] = None) -> Any:
```
"""

if self.getattr("delimiter") in name:
name, rest = name.split(self.getattr("delimiter"), 1)
delimiter = self.getattr("delimiter", ".")
if delimiter in name:
name, rest = name.split(delimiter, 1)
if name not in self:
raise KeyError(f"{self.__class__.__name__} does not contain {name}")
return self[name].pop(rest, default)
Expand All @@ -226,11 +228,13 @@ def all_keys(self):
```
"""

delimiter = self.getattr("delimiter", ".")

@wraps(self.all_keys)
def all_keys(self, prefix=""):
for key, value in self.items():
if prefix:
key = prefix + self.getattr("delimiter") + key
key = prefix + delimiter + key
if isinstance(value, NestedDict):
yield from all_keys(value, key)
else:
Expand Down Expand Up @@ -270,11 +274,13 @@ def all_items(self):
```
"""

delimiter = self.getattr("delimiter", ".")

@wraps(self.all_items)
def all_items(self, prefix=""):
for key, value in self.items():
if prefix:
key = prefix + self.getattr("delimiter") + key
key = prefix + delimiter + key
if isinstance(value, NestedDict):
yield from all_items(value, key)
else:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from chanfig import Config, Variable


class TestConfig(Config):
__test__ = False
def __init__(self, *args, **kwargs):
num_classes = Variable(10)
self.name = "CHANfiG"
self.seed = Variable(1013)
self.network.name = "ResNet18"
self.network.num_classes = num_classes
self.dataset.num_classes = num_classes


class Test:

config = TestConfig()

def test_value(self):
assert self.config.name == "CHANfiG"

def test_nested(self):
assert self.config.network.name == "ResNet18"

def test_variable(self):
assert self.config.network.num_classes == 10
self.config.network.num_classes += 1
assert self.config.dataset.num_classes == 11

def test_fstring(self):
assert f"seed{self.config.seed}" == "seed1013"

0 comments on commit 36f77ca

Please sign in to comment.