From 853622680f43cbfc0add50ec6953242a9e694456 Mon Sep 17 00:00:00 2001 From: "panxuchen.pxc" Date: Mon, 5 Feb 2024 14:52:03 +0800 Subject: [PATCH] register model wrapper using alias --- src/agentscope/models/__init__.py | 14 ++++---------- src/agentscope/models/model.py | 3 +++ src/agentscope/models/openai_model.py | 6 ++++++ src/agentscope/models/post_model.py | 6 ++++++ 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py index dff60ce89..72b8e21d9 100644 --- a/src/agentscope/models/__init__.py +++ b/src/agentscope/models/__init__.py @@ -34,14 +34,6 @@ _MODEL_CONFIGS: dict[str, dict] = {} -_MODEL_MAP: dict[str, Type[ModelWrapperBase]] = { - "openai": OpenAIChatWrapper, - "openai_dall_e": OpenAIDALLEWrapper, - "openai_embedding": OpenAIEmbeddingWrapper, - "post_api": PostAPIModelWrapperBase, - "post_api_chat": PostAPIChatWrapper, -} - def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]: """Get the specific type of model wrapper @@ -52,8 +44,10 @@ def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]: Returns: `Type[ModelWrapperBase]`: The corresponding model wrapper class. """ - if model_type in _MODEL_MAP: - return _MODEL_MAP[model_type] + if model_type in ModelWrapperBase.alias_registry: + return ModelWrapperBase.alias_registry[ # type: ignore [return-value] + model_type + ] elif model_type in ModelWrapperBase.registry: return ModelWrapperBase.registry[ # type: ignore [return-value] model_type diff --git a/src/agentscope/models/model.py b/src/agentscope/models/model.py index d4600f98b..43d6b5ebd 100644 --- a/src/agentscope/models/model.py +++ b/src/agentscope/models/model.py @@ -190,8 +190,11 @@ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any: def __init__(cls, name: Any, bases: Any, attrs: Any) -> None: if not hasattr(cls, "registry"): cls.registry = {} + cls.alias_registry = {} else: cls.registry[name] = cls + if hasattr(cls, "alias"): + cls.alias_registry[cls.alias] = cls super().__init__(name, bases, attrs) diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py index 1f6afe6f3..d72db2145 100644 --- a/src/agentscope/models/openai_model.py +++ b/src/agentscope/models/openai_model.py @@ -126,6 +126,8 @@ def _metric(self, metric_name: str) -> str: class OpenAIChatWrapper(OpenAIWrapper): """The model wrapper for OpenAI's chat API.""" + alias: str = "openai" + def _register_default_metrics(self) -> None: # Set monitor accordingly # TODO: set quota to the following metrics @@ -234,6 +236,8 @@ def __call__( class OpenAIDALLEWrapper(OpenAIWrapper): """The model wrapper for OpenAI's DALLĀ·E API.""" + alias: str = "openai_dall_e" + _resolutions: list = [ "1792*1024", "1024*1792", @@ -330,6 +334,8 @@ def __call__( class OpenAIEmbeddingWrapper(OpenAIWrapper): """The model wrapper for OpenAI embedding API.""" + alias: str = "openai_embedding" + def _register_default_metrics(self) -> None: # Set monitor accordingly # TODO: set quota to the following metrics diff --git a/src/agentscope/models/post_model.py b/src/agentscope/models/post_model.py index 95e79f867..4865bc2dd 100644 --- a/src/agentscope/models/post_model.py +++ b/src/agentscope/models/post_model.py @@ -16,6 +16,8 @@ class PostAPIModelWrapperBase(ModelWrapperBase): """The base model wrapper for the model deployed on the POST API.""" + alias: str = "post_api" + def __init__( self, model_id: str, @@ -172,6 +174,8 @@ class PostAPIChatWrapper(PostAPIModelWrapperBase): """A post api model wrapper compatilble with openai chat, e.g., vLLM, FastChat.""" + alias: str = "post_api_chat" + def _parse_response(self, response: dict) -> ModelResponse: return ModelResponse( text=response["data"]["response"]["choices"][0]["message"][ @@ -183,6 +187,8 @@ def _parse_response(self, response: dict) -> ModelResponse: class PostAPIDALLEWrapper(PostAPIModelWrapperBase): """A post api model wrapper compatible with openai dalle""" + alias: str = "post_api_dalle" + def _parse_response(self, response: dict) -> ModelResponse: urls = [img["url"] for img in response["data"]["response"]["data"]] return ModelResponse(image_urls=urls)