Skip to content

Commit

Permalink
add ConfigRegistry
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Mar 28, 2024
1 parent a3a4fde commit 2ccb854
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 5 deletions.
7 changes: 6 additions & 1 deletion chanfig/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .functional import load, save
from .nested_dict import NestedDict, apply, apply_
from .parser import ConfigParser
from .registry import GlobalRegistry, Registry
from .registry import ConfigRegistry, GlobalRegistry, Registry
from .variable import Variable

__all__ = [
Expand All @@ -37,6 +37,7 @@
"NestedDict",
"FlatDict",
"Registry",
"ConfigRegistry",
"GlobalRegistry",
"DefaultDict",
"ConfigParser",
Expand All @@ -56,7 +57,11 @@
add_representer(NestedDict, SafeRepresenter.represent_dict)
add_representer(DefaultDict, SafeRepresenter.represent_dict)
add_representer(Config, SafeRepresenter.represent_dict)
add_representer(Registry, SafeRepresenter.represent_dict)
add_representer(ConfigRegistry, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(FlatDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(NestedDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(DefaultDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(Config, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(Registry, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(ConfigRegistry, SafeRepresenter.represent_dict)
6 changes: 4 additions & 2 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def post(self) -> Self | None:
Note that you should always call `boot` to apply `post` rather than calling `post` directly,
as `boot` recursively call `post` on sub-configs.
See Also: [`boot`][chanfig.Config.boot]
See Also:
[`boot`][chanfig.Config.boot]
Returns:
self:
Expand Down Expand Up @@ -199,7 +200,8 @@ def boot(self) -> Self:
By default, `boot` is called after `Config` is parsed.
If you don't need to parse command-line arguments, you should call `boot` manually.
See Also: [`post`][chanfig.Config.post]
See Also:
[`post`][chanfig.Config.post]
Returns:
self:
Expand Down
126 changes: 125 additions & 1 deletion chanfig/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Registry(NestedDict):
"""
`Registry` for components.
Registry provides 3 core functionalities:
`Registry` provides 3 core functionalities:
- Register a new component.
- Lookup for a component.
Expand Down Expand Up @@ -55,6 +55,9 @@ class Registry(NestedDict):
You could create a sub-registry by simply calling `registry.sub_registry = Registry`,
and access through `registry.sub_registry.register()`.
See Also:
[`ConfigRegistry`][chanfig.ConfigRegistry]: Optimised for components that can be initialised with a `config`.
Examples:
>>> registry = Registry()
>>> @registry.register
Expand Down Expand Up @@ -254,4 +257,125 @@ def build(self, name: str | MutableMapping | None = None, *args: Any, **kwargs:
return self.init(self.lookup(name), *args, **kwargs) # type: ignore[arg-type]


class ConfigRegistry(Registry):
"""
`ConfigRegistry` for components that can be initialised with a `config`.
`ConfigRegistry` is purcutularly useful when you want to construct a component from a configuration, such as a
Hugginface Transformers model.
See Also:
[`Registry`][chanfig.Registry]: General purpose Registry.
Examples:
>>> from dataclasses import dataclass, field
>>> @dataclass
... class Config:
... a: int
... b: int
... mode: str = "proj"
>>> registry = ConfigRegistry(key="mode")
>>> @registry.register("proj")
... class Proj:
... def __init__(self, config):
... self.a = config.a
... self.b = config.b
>>> @registry.register("inv")
... class Inv:
... def __init__(self, config):
... self.a = config.b
... self.b = config.a
>>> registry
ConfigRegistry(
('proj'): <class 'chanfig.registry.Proj'>
('inv'): <class 'chanfig.registry.Inv'>
)
>>> config = Config(a=0, b=1)
>>> module = registry.build(config)
>>> module.a, module.b
(0, 1)
>>> config = Config(a=0, b=1, mode="inv")
>>> module = registry.build(config)
>>> module.a, module.b
(1, 0)
>>> @dataclass
... class ModuleConfig:
... a: int = 0
... b: int = 1
... mode: str = "proj"
>>> @dataclass
... class NestedConfig:
... module: ModuleConfig = field(default_factory=ModuleConfig)
>>> nested_registry = ConfigRegistry(key="module.mode")
>>> @nested_registry.register("proj")
... class Proj:
... def __init__(self, config):
... self.a = config.module.a
... self.b = config.module.b
>>> @nested_registry.register("inv")
... class Inv:
... def __init__(self, config):
... self.a = config.module.b
... self.b = config.module.a
>>> nested_config = NestedConfig()
>>> module = nested_registry.build(nested_config)
>>> module.a, module.b
(0, 1)
"""

def build(self, config) -> Any: # type: ignore[override]
r"""
Build a component.
Args:
config
Returns:
(Any):
Raises:
KeyError: If the component is not registered.
Examples:
>>> from dataclasses import dataclass, field
>>> registry = ConfigRegistry(key="module.mode")
>>> @registry.register("proj")
... class Proj:
... def __init__(self, config):
... self.a = config.module.a
... self.b = config.module.b
>>> @registry.register("inv")
... class Inv:
... def __init__(self, config):
... self.a = config.module.b
... self.b = config.module.a
>>> @dataclass
... class ModuleConfig:
... a: int = 0
... b: int = 1
... mode: str = "proj"
>>> @dataclass
... class Config:
... module: ModuleConfig = field(default_factory=ModuleConfig)
>>> config = Config()
>>> module = registry.build(config)
>>> type(module)
<class 'chanfig.registry.Proj'>
>>> module.a, module.b
(0, 1)
>>> type(module)
<class 'chanfig.registry.Proj'>
"""

key = self.key
config_ = deepcopy(config)

while "." in key:
key, rest = key.split(".", 1)
config_, key = getattr(config_, key), rest
name = getattr(config_, key)

return self.init(self.lookup(name), config) # type: ignore[arg-type]


GlobalRegistry = Registry()
2 changes: 1 addition & 1 deletion docs/docs/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ date: 2022-05-04

# Registry

::: chanfig.Registry
::: chanfig.registry

0 comments on commit 2ccb854

Please sign in to comment.