Skip to content

Commit

Permalink
rewrite apply and apply_ for better extensibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed May 20, 2023
1 parent 7ca8564 commit 02ee44f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 83 deletions.
50 changes: 8 additions & 42 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,50 +17,14 @@
from __future__ import annotations

import sys
from argparse import ArgumentParser, Namespace, _StoreAction
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Iterable, Optional, Sequence
from warnings import warn

from .nested_dict import NestedDict
from .utils import Null


class StoreAction(_StoreAction): # pylint: disable=R0903
def __init__( # pylint: disable=R0913
self,
option_strings,
dest,
nargs=None,
const=None,
default=Null,
type=None, # pylint: disable=W0622
choices=None,
required=False,
help=None, # pylint: disable=W0622
metavar=None,
):
if dest is not None and type is not None:
warn(f"type of argument {dest} is set to {type}, but CHANfiG will ignore it.")
type = None
super().__init__(
option_strings=option_strings,
dest=dest,
nargs=nargs,
const=const,
default=default,
type=type,
choices=choices,
required=required,
help=help,
metavar=metavar,
)
if self.default is not Null:
warn(
f"Default value for argument {self.dest} is set to {self.default}, "
"Default value defined in argument will be overwritten by default value defined in Config",
)
from .utils import Null, StoreAction


class ConfigParser(ArgumentParser): # pylint: disable=C0115
Expand Down Expand Up @@ -456,10 +420,11 @@ def freeze(self, recursive: bool = True) -> Config:

@wraps(self.freeze)
def freeze(config: Config) -> None:
config.setattr("frozen", True)
if isinstance(config, Config):
config.setattr("frozen", True)

if recursive:
self.apply(freeze)
self.apply_(freeze)
else:
freeze(self)
return self
Expand Down Expand Up @@ -527,10 +492,11 @@ def defrost(self, recursive: bool = True) -> Config:

@wraps(self.defrost)
def defrost(config: Config) -> None:
config.setattr("frozen", False)
if isinstance(config, Config):
config.setattr("frozen", False)

if recursive:
self.apply(defrost)
self.apply_(defrost)
else:
defrost(self)
return self
Expand Down
3 changes: 3 additions & 0 deletions chanfig/flat_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def dict(self, cls: Callable = dict) -> Mapping:
Returns:
(Mapping):
See Also:
[`to_dict`][chanfig.flat_dict.to_dict]: implementation of `dict` method.
Examples:
>>> d = FlatDict(a=1, b=2, c=3)
>>> d.dict()
Expand Down
86 changes: 46 additions & 40 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,18 @@
from __future__ import annotations

from functools import wraps
from inspect import ismethod
from os import PathLike
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Optional, Tuple, Union

from .default_dict import DefaultDict
from .flat_dict import PathStr
from .utils import Null
from .utils import Null, apply, apply_

if TYPE_CHECKING:
from torch import device as TorchDevice
from torch import dtype as TorchDtype


def apply(obj: Any, func: Callable, *args, **kwargs) -> Any:
r"""
Apply `func` to all children of `obj`.
Note that this method is meant for non-in-place modification of `obj` and should return a new object.
Args:
obj: Object to apply function.
func: Function to be applied.
*args: Positional arguments to be passed to `func`.
**kwargs: Keyword arguments to be passed to `func`.
Returns:
(Any): Return value of `func`.
See Also:
[`apply_`][chanfig.nested_dict.apply_]: Apply a in-place operation.
"""

if isinstance(obj, Mapping):
{k: apply(v, func, *args, **kwargs) for k, v in obj.items()}
if isinstance(obj, list):
[apply(v, func, *args, **kwargs) for v in obj] # type: ignore
if isinstance(obj, tuple):
tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
if isinstance(obj, set):
try:
set(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
except TypeError:
tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
if isinstance(obj, NestedDict):
return func(*args, **kwargs) if ismethod(func) else func(obj, *args, **kwargs)
return obj


class NestedDict(DefaultDict):
r"""
`NestedDict` further extends `DefaultDict` object by introducing a nested structure with `delimiter`.
Expand Down Expand Up @@ -211,22 +175,60 @@ def apply(self, func: Callable, *args, **kwargs) -> NestedDict:
r"""
Recursively apply a function to `NestedDict` and its children.
Note:
This method is meant for non-in-place modification of `obj`, for example, [`to`][chanfig.NestedDict.to].
Args:
func(Callable):
See Also:
[`apply_`][chanfig.NestedDict.apply_]: Apply a in-place operation.
[`apply`][chanfig.utils.apply]: implementation of `apply` method.
Examples:
>>> def func(d):
... d.t = 1
... if isinstance(d, NestedDict):
... d.t = 1
>>> d = NestedDict()
>>> d.a = NestedDict()
>>> d.b = [NestedDict(),]
>>> d.c = (NestedDict(),)
>>> d.d = {NestedDict(),}
>>> d.apply(func).dict()
{'a': {}, 'b': [{}], 'c': ({},), 'd': ({},)}
"""

return apply(self, func, *args, **kwargs)

def apply_(self, func: Callable, *args, **kwargs) -> NestedDict:
r"""
Recursively apply a function to `NestedDict` and its children.
Note:
This method is meant for in-place modification of `obj`, for example, [`freeze`][chanfig.Config.freeze].
Args:
func(Callable):
See Also:
[`apply`][chanfig.NestedDict.apply]: Apply a non-in-place operation.
[`apply_`][chanfig.utils.apply_]: implementation of `apply_` method.
Examples:
>>> def func(d):
... if isinstance(d, NestedDict):
... d.t = 1
>>> d = NestedDict()
>>> d.a = NestedDict()
>>> d.b = [NestedDict(),]
>>> d.c = (NestedDict(),)
>>> d.d = {NestedDict(),}
>>> d.apply_(func).dict()
{'a': {'t': 1}, 'b': [{'t': 1}], 'c': ({'t': 1},), 'd': ({'t': 1},), 't': 1}
"""

return apply(self, func, *args, **kwargs) or self
apply_(self, func, *args, **kwargs)
return self

def get(self, name: Any, default: Any = Null) -> Any:
r"""
Expand Down Expand Up @@ -590,7 +592,11 @@ def to(self, cls: Union[str, TorchDevice, TorchDtype]) -> Any:
{'i': {'d': tensor(1013)}}
"""

return self.apply(super().to, cls)
def to(obj):
if hasattr(obj, "to"):
return obj.to(cls)

return self.apply(to)

def dropnull(self) -> NestedDict:
r"""
Expand Down
98 changes: 97 additions & 1 deletion chanfig/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,74 @@
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the LICENSE file for more details.

from argparse import _StoreAction
from inspect import ismethod
from json import JSONEncoder
from typing import Any, Mapping
from typing import Any, Callable, Mapping
from warnings import warn

from yaml import SafeDumper, SafeLoader

from .variable import Variable


def apply(obj: Any, func: Callable, *args, **kwargs) -> Any:
r"""
Apply `func` to all children of `obj`.
Note that this function is meant for non-in-place modification of `obj` and should return the original object.
Args:
obj: Object to apply function.
func: Function to be applied.
*args: Positional arguments to be passed to `func`.
**kwargs: Keyword arguments to be passed to `func`.
Returns:
(Any): Return value of `func`.
See Also:
[`apply_`][chanfig.utils.apply_]: Apply a in-place operation.
"""

if isinstance(obj, Mapping):
return type(obj)({k: apply(v, func, *args, **kwargs) for k, v in obj.items()}) # type: ignore
if isinstance(obj, (list, tuple)):
return type(obj)(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
if isinstance(obj, set):
try:
return type(obj)(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
except TypeError:
tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
return func(*args, **kwargs) if ismethod(func) else func(obj, *args, **kwargs)


def apply_(obj: Any, func: Callable, *args, **kwargs) -> Any:
r"""
Apply `func` to all children of `obj`.
Note that this function is meant for non-in-place modification of `obj` and should return a new object.
Args:
obj: Object to apply function.
func: Function to be applied.
*args: Positional arguments to be passed to `func`.
**kwargs: Keyword arguments to be passed to `func`.
Returns:
(Any): Return value of `func`.
See Also:
[`apply_`][chanfig.utils.apply_]: Apply a in-place operation.
"""

if isinstance(obj, Mapping):
[apply_(v, func, *args, **kwargs) for v in obj.values()] # type: ignore
if isinstance(obj, (list, tuple, set)):
[apply_(v, func, *args, **kwargs) for v in obj] # type: ignore
return func(*args, **kwargs) if ismethod(func) else func(obj, *args, **kwargs)


class Singleton(type):
r"""
Metaclass for Singleton Classes.
Expand Down Expand Up @@ -102,3 +162,39 @@ class YamlLoader(SafeLoader): # pylint: disable=R0901,R0903


Null = NULL()


class StoreAction(_StoreAction): # pylint: disable=R0903
def __init__( # pylint: disable=R0913
self,
option_strings,
dest,
nargs=None,
const=None,
default=Null,
type=None, # pylint: disable=W0622
choices=None,
required=False,
help=None, # pylint: disable=W0622
metavar=None,
):
if dest is not None and type is not None:
warn(f"type of argument {dest} is set to {type}, but CHANfiG will ignore it.")
type = None
super().__init__(
option_strings=option_strings,
dest=dest,
nargs=nargs,
const=const,
default=default,
type=type,
choices=choices,
required=required,
help=help,
metavar=metavar,
)
if self.default is not Null:
warn(
f"Default value for argument {self.dest} is set to {self.default}, "
"Default value defined in argument will be overwritten by default value defined in Config",
)

0 comments on commit 02ee44f

Please sign in to comment.