From 2ccb8548bf271cb848c192d775f37dc9b45d818d Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 28 Mar 2024 19:05:48 +0800 Subject: [PATCH] add ConfigRegistry Signed-off-by: Zhiyuan Chen --- chanfig/__init__.py | 7 ++- chanfig/config.py | 6 +- chanfig/registry.py | 126 +++++++++++++++++++++++++++++++++++++++++- docs/docs/registry.md | 2 +- 4 files changed, 136 insertions(+), 5 deletions(-) diff --git a/chanfig/__init__.py b/chanfig/__init__.py index 91cd7f81..2625fd2c 100755 --- a/chanfig/__init__.py +++ b/chanfig/__init__.py @@ -27,7 +27,7 @@ from .functional import load, save from .nested_dict import NestedDict, apply, apply_ from .parser import ConfigParser -from .registry import GlobalRegistry, Registry +from .registry import ConfigRegistry, GlobalRegistry, Registry from .variable import Variable __all__ = [ @@ -37,6 +37,7 @@ "NestedDict", "FlatDict", "Registry", + "ConfigRegistry", "GlobalRegistry", "DefaultDict", "ConfigParser", @@ -56,7 +57,11 @@ add_representer(NestedDict, SafeRepresenter.represent_dict) add_representer(DefaultDict, SafeRepresenter.represent_dict) add_representer(Config, SafeRepresenter.represent_dict) +add_representer(Registry, SafeRepresenter.represent_dict) +add_representer(ConfigRegistry, SafeRepresenter.represent_dict) SafeRepresenter.add_representer(FlatDict, SafeRepresenter.represent_dict) SafeRepresenter.add_representer(NestedDict, SafeRepresenter.represent_dict) SafeRepresenter.add_representer(DefaultDict, SafeRepresenter.represent_dict) SafeRepresenter.add_representer(Config, SafeRepresenter.represent_dict) +SafeRepresenter.add_representer(Registry, SafeRepresenter.represent_dict) +SafeRepresenter.add_representer(ConfigRegistry, SafeRepresenter.represent_dict) diff --git a/chanfig/config.py b/chanfig/config.py index 6fea7ebe..444b518a 100755 --- a/chanfig/config.py +++ b/chanfig/config.py @@ -164,7 +164,8 @@ def post(self) -> Self | None: Note that you should always call `boot` to apply `post` rather than calling `post` 
directly, as `boot` recursively call `post` on sub-configs. - See Also: [`boot`][chanfig.Config.boot] + See Also: + [`boot`][chanfig.Config.boot] Returns: self: @@ -199,7 +200,8 @@ def boot(self) -> Self: By default, `boot` is called after `Config` is parsed. If you don't need to parse command-line arguments, you should call `boot` manually. - See Also: [`post`][chanfig.Config.post] + See Also: + [`post`][chanfig.Config.post] Returns: self: diff --git a/chanfig/registry.py b/chanfig/registry.py index 25be85bc..f8e512e2 100644 --- a/chanfig/registry.py +++ b/chanfig/registry.py @@ -27,7 +27,7 @@ class Registry(NestedDict): """ `Registry` for components. - Registry provides 3 core functionalities: + `Registry` provides 3 core functionalities: - Register a new component. - Lookup for a component. @@ -55,6 +55,9 @@ class Registry(NestedDict): You could create a sub-registry by simply calling `registry.sub_registry = Registry`, and access through `registry.sub_registry.register()`. + See Also: + [`ConfigRegistry`][chanfig.ConfigRegistry]: Optimised for components that can be initialised with a `config`. + Examples: >>> registry = Registry() >>> @registry.register @@ -254,4 +257,125 @@ def build(self, name: str | MutableMapping | None = None, *args: Any, **kwargs: return self.init(self.lookup(name), *args, **kwargs) # type: ignore[arg-type] +class ConfigRegistry(Registry): + """ + `ConfigRegistry` for components that can be initialised with a `config`. + + `ConfigRegistry` is particularly useful when you want to construct a component from a configuration, such as a + Hugging Face Transformers model. + + See Also: + [`Registry`][chanfig.Registry]: General purpose Registry. + + Examples: + >>> from dataclasses import dataclass, field + >>> @dataclass + ... class Config: + ... a: int + ... b: int + ... mode: str = "proj" + >>> registry = ConfigRegistry(key="mode") + >>> @registry.register("proj") + ... class Proj: + ... def __init__(self, config): + ... 
self.a = config.a + ... self.b = config.b + >>> @registry.register("inv") + ... class Inv: + ... def __init__(self, config): + ... self.a = config.b + ... self.b = config.a + >>> registry + ConfigRegistry( + ('proj'): <class 'chanfig.registry.Proj'> + ('inv'): <class 'chanfig.registry.Inv'> + ) + >>> config = Config(a=0, b=1) + >>> module = registry.build(config) + >>> module.a, module.b + (0, 1) + >>> config = Config(a=0, b=1, mode="inv") + >>> module = registry.build(config) + >>> module.a, module.b + (1, 0) + >>> @dataclass + ... class ModuleConfig: + ... a: int = 0 + ... b: int = 1 + ... mode: str = "proj" + >>> @dataclass + ... class NestedConfig: + ... module: ModuleConfig = field(default_factory=ModuleConfig) + >>> nested_registry = ConfigRegistry(key="module.mode") + >>> @nested_registry.register("proj") + ... class Proj: + ... def __init__(self, config): + ... self.a = config.module.a + ... self.b = config.module.b + >>> @nested_registry.register("inv") + ... class Inv: + ... def __init__(self, config): + ... self.a = config.module.b + ... self.b = config.module.a + >>> nested_config = NestedConfig() + >>> module = nested_registry.build(nested_config) + >>> module.a, module.b + (0, 1) + """ + + def build(self, config) -> Any: # type: ignore[override] + r""" + Build a component. + + Args: + config: Configuration object. The component name is read from the attribute at the (possibly dotted) `key` path, and `config` itself is passed to the component. + + Returns: + (Any): The component built from the looked-up entry and `config`. + + Raises: + KeyError: If the component is not registered. + + Examples: + >>> from dataclasses import dataclass, field + >>> registry = ConfigRegistry(key="module.mode") + >>> @registry.register("proj") + ... class Proj: + ... def __init__(self, config): + ... self.a = config.module.a + ... self.b = config.module.b + >>> @registry.register("inv") + ... class Inv: + ... def __init__(self, config): + ... self.a = config.module.b + ... self.b = config.module.a + >>> @dataclass + ... class ModuleConfig: + ... a: int = 0 + ... b: int = 1 + ... mode: str = "proj" + >>> @dataclass + ... class Config: + ... 
module: ModuleConfig = field(default_factory=ModuleConfig) + >>> config = Config() + >>> module = registry.build(config) + >>> type(module) + <class 'chanfig.registry.Proj'> + >>> module.a, module.b + (0, 1) + >>> type(module) + <class 'chanfig.registry.Proj'> + """ + + key = self.key + config_ = deepcopy(config) + + while "." in key: + key, rest = key.split(".", 1) + config_, key = getattr(config_, key), rest + name = getattr(config_, key) + + return self.init(self.lookup(name), config) # type: ignore[arg-type] + + GlobalRegistry = Registry() diff --git a/docs/docs/registry.md b/docs/docs/registry.md index d1432a71..ed728574 100644 --- a/docs/docs/registry.md +++ b/docs/docs/registry.md @@ -6,4 +6,4 @@ date: 2022-05-04 # Registry -::: chanfig.Registry +::: chanfig.registry