From 7ca856401a8ff049f40f134911f65fe546a5818a Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sat, 20 May 2023 18:03:08 +0800 Subject: [PATCH] make apply traverse all data container in NestedDict --- chanfig/nested_dict.py | 52 +++++++++++++++++++++++++++++++++++------- tests/test.py | 3 ++- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/chanfig/nested_dict.py b/chanfig/nested_dict.py index c3636829..b82bd5ac 100644 --- a/chanfig/nested_dict.py +++ b/chanfig/nested_dict.py @@ -17,6 +17,7 @@ from __future__ import annotations from functools import wraps +from inspect import ismethod from os import PathLike from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Optional, Tuple, Union @@ -29,6 +30,41 @@ from torch import dtype as TorchDtype +def apply(obj: Any, func: Callable, *args, **kwargs) -> Any: + r""" + Apply `func` to all children of `obj`. + + Note that this method is meant for non-in-place modification of `obj` and should return a new object. + + Args: + obj: Object to apply function. + func: Function to be applied. + *args: Positional arguments to be passed to `func`. + **kwargs: Keyword arguments to be passed to `func`. + + Returns: + (Any): Return value of `func`. + + See Also: + [`apply_`][chanfig.nested_dict.apply_]: Apply a in-place operation. + """ + + if isinstance(obj, Mapping): + {k: apply(v, func, *args, **kwargs) for k, v in obj.items()} + if isinstance(obj, list): + [apply(v, func, *args, **kwargs) for v in obj] # type: ignore + if isinstance(obj, tuple): + tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore + if isinstance(obj, set): + try: + set(apply(v, func, *args, **kwargs) for v in obj) # type: ignore + except TypeError: + tuple(apply(v, func, *args, **kwargs) for v in obj) # type: ignore + if isinstance(obj, NestedDict): + return func(*args, **kwargs) if ismethod(func) else func(obj, *args, **kwargs) + return obj + + class NestedDict(DefaultDict): r""" `NestedDict` further extends `DefaultDict` object by introducing a nested structure with `delimiter`. @@ -179,18 +215,18 @@ def apply(self, func: Callable, *args, **kwargs) -> NestedDict: func(Callable): Examples: - >>> d = NestedDict() - >>> d.a = NestedDict() >>> def func(d): ... d.t = 1 + >>> d = NestedDict() + >>> d.a = NestedDict() + >>> d.b = [NestedDict(),] + >>> d.c = (NestedDict(),) + >>> d.d = {NestedDict(),} >>> d.apply(func).dict() - {'a': {'t': 1}, 't': 1} + {'a': {'t': 1}, 'b': [{'t': 1}], 'c': ({'t': 1},), 'd': ({'t': 1},), 't': 1} """ - for value in self.values(): - if isinstance(value, NestedDict): - value.apply(func, *args, **kwargs) - return func(self, *args, **kwargs) or self + return apply(self, func, *args, **kwargs) or self def get(self, name: Any, default: Any = Null) -> Any: r""" @@ -554,7 +590,7 @@ def to(self, cls: Union[str, TorchDevice, TorchDtype]) -> Any: {'i': {'d': tensor(1013)}} """ - return self.apply(lambda _: super().to(cls)) + return self.apply(super().to, cls) def dropnull(self) -> NestedDict: r""" diff --git a/tests/test.py b/tests/test.py index 92fe7ebb..7e81836c 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,8 +1,9 @@ from functools import partial -from chanfig import Config, Variable from torch import nn +from chanfig import Config, Variable + class Model: def __init__(self, encoder, decoder, dropout=0.1, activation="ReLU"):