Skip to content

Commit

Permalink
make apply traverse all data container in NestedDict
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuanChen committed May 20, 2023
1 parent 87a9585 commit 7ca8564
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
52 changes: 44 additions & 8 deletions chanfig/nested_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

from functools import wraps
from inspect import ismethod
from os import PathLike
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Optional, Tuple, Union

Expand All @@ -29,6 +30,41 @@
from torch import dtype as TorchDtype


def apply(obj: Any, func: Callable, *args, **kwargs) -> Any:
r"""
Apply `func` to all children of `obj`.
Note that this method is meant for non-in-place modification of `obj` and should return a new object.
Args:
obj: Object to apply function.
func: Function to be applied.
*args: Positional arguments to be passed to `func`.
**kwargs: Keyword arguments to be passed to `func`.
Returns:
(Any): Return value of `func`.
See Also:
[`apply_`][chanfig.nested_dict.apply_]: Apply a in-place operation.
"""

if isinstance(obj, Mapping):
{k: apply(v, func, *args, **kwargs) for k, v in obj.items()}
if isinstance(obj, list):
[apply(v, func, *args, **kwargs) for v in obj] # type: ignore
if isinstance(obj, tuple):
tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
if isinstance(obj, set):
try:
set(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
except TypeError:
tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore
if isinstance(obj, NestedDict):
return func(*args, **kwargs) if ismethod(func) else func(obj, *args, **kwargs)
return obj


class NestedDict(DefaultDict):
r"""
`NestedDict` further extends `DefaultDict` object by introducing a nested structure with `delimiter`.
Expand Down Expand Up @@ -179,18 +215,18 @@ def apply(self, func: Callable, *args, **kwargs) -> NestedDict:
func(Callable):
Examples:
>>> d = NestedDict()
>>> d.a = NestedDict()
>>> def func(d):
... d.t = 1
>>> d = NestedDict()
>>> d.a = NestedDict()
>>> d.b = [NestedDict(),]
>>> d.c = (NestedDict(),)
>>> d.d = {NestedDict(),}
>>> d.apply(func).dict()
{'a': {'t': 1}, 't': 1}
{'a': {'t': 1}, 'b': [{'t': 1}], 'c': ({'t': 1},), 'd': ({'t': 1},), 't': 1}
"""

for value in self.values():
if isinstance(value, NestedDict):
value.apply(func, *args, **kwargs)
return func(self, *args, **kwargs) or self
return apply(self, func, *args, **kwargs) or self

def get(self, name: Any, default: Any = Null) -> Any:
r"""
Expand Down Expand Up @@ -554,7 +590,7 @@ def to(self, cls: Union[str, TorchDevice, TorchDtype]) -> Any:
{'i': {'d': tensor(1013)}}
"""

return self.apply(lambda _: super().to(cls))
return self.apply(super().to, cls)

def dropnull(self) -> NestedDict:
r"""
Expand Down
3 changes: 2 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import partial

from chanfig import Config, Variable
from torch import nn

from chanfig import Config, Variable


class Model:
def __init__(self, encoder, decoder, dropout=0.1, activation="ReLU"):
Expand Down

0 comments on commit 7ca8564

Please sign in to comment.