Skip to content

Commit

Permalink
register model wrapper using alias
Browse files Browse the repository at this point in the history
  • Loading branch information
pan-x-c committed Feb 5, 2024
1 parent 9d772ea commit 8536226
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@

_MODEL_CONFIGS: dict[str, dict] = {}

_MODEL_MAP: dict[str, Type[ModelWrapperBase]] = {
"openai": OpenAIChatWrapper,
"openai_dall_e": OpenAIDALLEWrapper,
"openai_embedding": OpenAIEmbeddingWrapper,
"post_api": PostAPIModelWrapperBase,
"post_api_chat": PostAPIChatWrapper,
}


def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]:
"""Get the specific type of model wrapper
Expand All @@ -52,8 +44,10 @@ def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]:
Returns:
`Type[ModelWrapperBase]`: The corresponding model wrapper class.
"""
if model_type in _MODEL_MAP:
return _MODEL_MAP[model_type]
if model_type in ModelWrapperBase.alias_registry:
return ModelWrapperBase.alias_registry[ # type: ignore [return-value]
model_type
]
elif model_type in ModelWrapperBase.registry:
return ModelWrapperBase.registry[ # type: ignore [return-value]
model_type
Expand Down
3 changes: 3 additions & 0 deletions src/agentscope/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,11 @@ def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
if not hasattr(cls, "registry"):
cls.registry = {}
cls.alias_registry = {}
else:
cls.registry[name] = cls
if hasattr(cls, "alias"):
cls.alias_registry[cls.alias] = cls
super().__init__(name, bases, attrs)


Expand Down
6 changes: 6 additions & 0 deletions src/agentscope/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def _metric(self, metric_name: str) -> str:
class OpenAIChatWrapper(OpenAIWrapper):
"""The model wrapper for OpenAI's chat API."""

alias: str = "openai"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
# TODO: set quota to the following metrics
Expand Down Expand Up @@ -234,6 +236,8 @@ def __call__(
class OpenAIDALLEWrapper(OpenAIWrapper):
"""The model wrapper for OpenAI's DALL·E API."""

alias: str = "openai_dall_e"

_resolutions: list = [
"1792*1024",
"1024*1792",
Expand Down Expand Up @@ -330,6 +334,8 @@ def __call__(
class OpenAIEmbeddingWrapper(OpenAIWrapper):
"""The model wrapper for OpenAI embedding API."""

alias: str = "openai_embedding"

def _register_default_metrics(self) -> None:
# Set monitor accordingly
# TODO: set quota to the following metrics
Expand Down
6 changes: 6 additions & 0 deletions src/agentscope/models/post_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
class PostAPIModelWrapperBase(ModelWrapperBase):
"""The base model wrapper for the model deployed on the POST API."""

alias: str = "post_api"

def __init__(
self,
model_id: str,
Expand Down Expand Up @@ -172,6 +174,8 @@ class PostAPIChatWrapper(PostAPIModelWrapperBase):
"""A post api model wrapper compatilble with openai chat, e.g., vLLM,
FastChat."""

alias: str = "post_api_chat"

def _parse_response(self, response: dict) -> ModelResponse:
return ModelResponse(
text=response["data"]["response"]["choices"][0]["message"][
Expand All @@ -183,6 +187,8 @@ def _parse_response(self, response: dict) -> ModelResponse:
class PostAPIDALLEWrapper(PostAPIModelWrapperBase):
"""A post api model wrapper compatible with openai dalle"""

alias: str = "post_api_dalle"

def _parse_response(self, response: dict) -> ModelResponse:
urls = [img["url"] for img in response["data"]["response"]["data"]]
return ModelResponse(image_urls=urls)

0 comments on commit 8536226

Please sign in to comment.