Skip to content

Commit

Permalink
add ConfigRegistry
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Mar 28, 2024
1 parent a3a4fde commit de424df
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 5 deletions.
7 changes: 6 additions & 1 deletion chanfig/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .functional import load, save
from .nested_dict import NestedDict, apply, apply_
from .parser import ConfigParser
from .registry import GlobalRegistry, Registry
from .registry import ConfigRegistry, GlobalRegistry, Registry
from .variable import Variable

__all__ = [
Expand All @@ -37,6 +37,7 @@
"NestedDict",
"FlatDict",
"Registry",
"ConfigRegistry",
"GlobalRegistry",
"DefaultDict",
"ConfigParser",
Expand All @@ -56,7 +57,11 @@
add_representer(NestedDict, SafeRepresenter.represent_dict)
add_representer(DefaultDict, SafeRepresenter.represent_dict)
add_representer(Config, SafeRepresenter.represent_dict)
add_representer(Registry, SafeRepresenter.represent_dict)
add_representer(ConfigRegistry, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(FlatDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(NestedDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(DefaultDict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(Config, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(Registry, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(ConfigRegistry, SafeRepresenter.represent_dict)
6 changes: 4 additions & 2 deletions chanfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def post(self) -> Self | None:
Note that you should always call `boot` to apply `post` rather than calling `post` directly,
as `boot` recursively call `post` on sub-configs.
See Also: [`boot`][chanfig.Config.boot]
See Also:
[`boot`][chanfig.Config.boot]
Returns:
self:
Expand Down Expand Up @@ -199,7 +200,8 @@ def boot(self) -> Self:
By default, `boot` is called after `Config` is parsed.
If you don't need to parse command-line arguments, you should call `boot` manually.
See Also: [`post`][chanfig.Config.post]
See Also:
[`post`][chanfig.Config.post]
Returns:
self:
Expand Down
91 changes: 90 additions & 1 deletion chanfig/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Registry(NestedDict):
"""
`Registry` for components.
Registry provides 3 core functionalities:
`Registry` provides 3 core functionalities:
- Register a new component.
- Lookup for a component.
Expand Down Expand Up @@ -55,6 +55,9 @@ class Registry(NestedDict):
You could create a sub-registry by simply calling `registry.sub_registry = Registry`,
and access through `registry.sub_registry.register()`.
See Also:
[`ConfigRegistry`][chanfig.ConfigRegistry]: Optimised for components that can be initialised with a `config`.
Examples:
>>> registry = Registry()
>>> @registry.register
Expand Down Expand Up @@ -254,4 +257,90 @@ def build(self, name: str | MutableMapping | None = None, *args: Any, **kwargs:
return self.init(self.lookup(name), *args, **kwargs) # type: ignore[arg-type]


class ConfigRegistry(Registry):
"""
`ConfigRegistry` for components that can be initialised with a `config`.
`ConfigRegistry` is purcutularly useful when you want to construct a component from a configuration, such as a
Hugginface Transformers model.
See Also:
[`Registry`][chanfig.Registry]: General purpose Registry.
Examples:
>>> from dataclasses import dataclass
>>> @dataclass
... class Config:
... a: int
... b: int
... proj_head_mode: str = "proj"
>>> registry = ConfigRegistry(key="proj_head_mode")
>>> @registry.register("proj")
... class Proj:
... def __init__(self, config):
... self.a = config.a
... self.b = config.b
>>> @registry.register("inv")
... class Inv:
... def __init__(self, config):
... self.a = config.b
... self.b = config.a
>>> registry
ConfigRegistry(
('proj'): <class 'chanfig.registry.Proj'>
('inv'): <class 'chanfig.registry.Inv'>
)
>>> config = Config(a=1, b=2)
>>> module = registry.build(config)
>>> type(module)
<class 'chanfig.registry.Proj'>
>>> module.a
1
>>> module.b
2
>>> config = Config(a=1, b=2, proj_head_mode="inv")
>>> module = registry.build(config)
>>> module.a
2
>>> module.b
1
"""

def build(self, config) -> Any: # type: ignore[override]
r"""
Build a component.
Args:
config
Returns:
(Any):
Raises:
KeyError: If the component is not registered.
Examples:
>>> registry = Registry(key="model")
>>> @registry.register
... class Module:
... def __init__(self, a, b):
... self.a = a
... self.b = b
>>> config = {"module": {"model": "Module", "a": 1, "b": 2}}
>>> # registry.register(Module)
>>> module = registry.build(**config["module"])
>>> type(module)
<class 'chanfig.registry.Module'>
>>> module.a
1
>>> module.b
2
>>> module = registry.build(config["module"], a=2)
>>> module.a
2
"""

return self.init(self.lookup(getattr(config, self.key)), config) # type: ignore[arg-type]


GlobalRegistry = Registry()
2 changes: 1 addition & 1 deletion docs/docs/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ date: 2022-05-04

# Registry

::: chanfig.Registry
::: chanfig.registry

0 comments on commit de424df

Please sign in to comment.