Skip to content

Commit

Permalink
support specify key for Registry.build
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Mar 28, 2024
1 parent 8f988e5 commit a3a4fde
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions chanfig/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

from collections.abc import Callable, Mapping
from collections.abc import Callable, MutableMapping
from copy import deepcopy
from functools import wraps
from typing import Any
Expand Down Expand Up @@ -88,8 +88,9 @@ class Registry(NestedDict):

override: bool = False

def __init__(self, override: bool | None = None, fallback: bool | None = None):
def __init__(self, override: bool | None = None, key: str = "name", fallback: bool | None = None):
super().__init__(fallback=fallback)
self.setattr("key", key)
if override is not None:
self.setattr("override", override)

Expand Down Expand Up @@ -207,14 +208,14 @@ def init(cls: Callable, *args: Any, **kwargs: Any) -> Any: # pylint: disable=W0

return cls(*args, **kwargs)

def build(self, name: str | Mapping, *args: Any, **kwargs: Any) -> Any:
def build(self, name: str | MutableMapping | None = None, *args: Any, **kwargs: Any) -> Any:
r"""
Build a component.
Args:
name (str | Mapping):
If its a `Mapping`, it must contain `"name"` as a member, the rest will be treated as `**kwargs`.
Note that values in `kwargs` will override values in `name` if its a `Mapping`.
name (str | MutableMapping):
If its a `MutableMapping`, it must contain `key` as a member, the rest will be treated as `**kwargs`.
Note that values in `kwargs` will override values in `name` if its a `MutableMapping`.
*args: The arguments to pass to the component.
**kwargs: The keyword arguments to pass to the component.
Expand All @@ -225,13 +226,13 @@ def build(self, name: str | Mapping, *args: Any, **kwargs: Any) -> Any:
KeyError: If the component is not registered.
Examples:
>>> registry = Registry()
>>> registry = Registry(key="model")
>>> @registry.register
... class Module:
... def __init__(self, a, b):
... self.a = a
... self.b = b
>>> config = {"module": {"name": "Module", "a": 1, "b": 2}}
>>> config = {"module": {"model": "Module", "a": 1, "b": 2}}
>>> # registry.register(Module)
>>> module = registry.build(**config["module"])
>>> type(module)
Expand All @@ -245,9 +246,11 @@ def build(self, name: str | Mapping, *args: Any, **kwargs: Any) -> Any:
2
"""

if isinstance(name, Mapping):
if isinstance(name, MutableMapping):
name = deepcopy(name)
name, kwargs = name.pop("name"), dict(name, **kwargs) # type: ignore[attr-defined, arg-type]
name, kwargs = name.pop(self.getattr("key", "name")), dict(name, **kwargs) # type: ignore[arg-type]
if name is None:
name, kwargs = kwargs.pop(self.getattr("key")), dict(**kwargs)
return self.init(self.lookup(name), *args, **kwargs) # type: ignore[arg-type]


Expand Down

0 comments on commit a3a4fde

Please sign in to comment.