From a3a4fdebe00e406f43874fdfed17d47db301ad34 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Thu, 28 Mar 2024 18:23:21 +0800 Subject: [PATCH] support specify key for Registry.build Signed-off-by: Zhiyuan Chen --- chanfig/registry.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/chanfig/registry.py b/chanfig/registry.py index 82d54cdc..25be85bc 100644 --- a/chanfig/registry.py +++ b/chanfig/registry.py @@ -15,7 +15,7 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable, MutableMapping from copy import deepcopy from functools import wraps from typing import Any @@ -88,8 +88,9 @@ class Registry(NestedDict): override: bool = False - def __init__(self, override: bool | None = None, fallback: bool | None = None): + def __init__(self, override: bool | None = None, key: str = "name", fallback: bool | None = None): super().__init__(fallback=fallback) + self.setattr("key", key) if override is not None: self.setattr("override", override) @@ -207,14 +208,14 @@ def init(cls: Callable, *args: Any, **kwargs: Any) -> Any: # pylint: disable=W0 return cls(*args, **kwargs) - def build(self, name: str | Mapping, *args: Any, **kwargs: Any) -> Any: + def build(self, name: str | MutableMapping | None = None, *args: Any, **kwargs: Any) -> Any: r""" Build a component. Args: - name (str | Mapping): - If its a `Mapping`, it must contain `"name"` as a member, the rest will be treated as `**kwargs`. - Note that values in `kwargs` will override values in `name` if its a `Mapping`. + name (str | MutableMapping): + If its a `MutableMapping`, it must contain `key` as a member, the rest will be treated as `**kwargs`. + Note that values in `kwargs` will override values in `name` if its a `MutableMapping`. *args: The arguments to pass to the component. **kwargs: The keyword arguments to pass to the component. @@ -225,13 +226,13 @@ def build(self, name: str | Mapping, *args: Any, **kwargs: Any) -> Any: KeyError: If the component is not registered. Examples: - >>> registry = Registry() + >>> registry = Registry(key="model") >>> @registry.register ... class Module: ... def __init__(self, a, b): ... self.a = a ... self.b = b - >>> config = {"module": {"name": "Module", "a": 1, "b": 2}} + >>> config = {"module": {"model": "Module", "a": 1, "b": 2}} >>> # registry.register(Module) >>> module = registry.build(**config["module"]) >>> type(module) @@ -245,9 +246,11 @@ def build(self, name: str | Mapping, *args: Any, **kwargs: Any) -> Any: 2 """ - if isinstance(name, Mapping): + if isinstance(name, MutableMapping): name = deepcopy(name) - name, kwargs = name.pop("name"), dict(name, **kwargs) # type: ignore[attr-defined, arg-type] + name, kwargs = name.pop(self.getattr("key", "name")), dict(name, **kwargs) # type: ignore[arg-type] + if name is None: + name, kwargs = kwargs.pop(self.getattr("key")), dict(**kwargs) return self.init(self.lookup(name), *args, **kwargs) # type: ignore[arg-type]