Skip to content

Commit

Permalink
Support explicit model wrapper registration (#504)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavdGao authored Jan 6, 2025
1 parent 4f897f9 commit 36a3943
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 201 deletions.
12 changes: 8 additions & 4 deletions docs/sphinx_doc/en/source/tutorial/203-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ In the current AgentScope, the supported `model_type` types, the corresponding
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | - | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
| | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_dall_e` | - | |
| | Embedding | [`PostAPIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_embedding` | - |
Expand Down Expand Up @@ -519,7 +519,7 @@ com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py">agen
```python
{
"config_name": "my_postapiwrapper_config",
"model_type": "post_api",
"model_type": "post_api_chat",

# Required parameters
"api_url": "https://xxx.xxx",
Expand Down Expand Up @@ -566,6 +566,7 @@ The new model wrapper class should
- inherit from `ModelWrapperBase` class,
- provide a `model_type` field to identify this model wrapper in the model configuration, and
- implement its `__init__` and `__call__` functions.
- register the new model wrapper class by calling `agentscope.register_model_wrapper_class` function

The following is an example for creating a new model wrapper class.

Expand All @@ -586,10 +587,13 @@ class MyModelWrapper(ModelWrapperBase):
# ...
```

After creating the new model wrapper class, the model wrapper will be registered into AgentScope automatically.
You can use it in the model configuration directly.
Then we register the new model wrapper class and use it in the model configuration.

```python
import agentscope

agentscope.register_model_wrapper_class(MyModelWrapper)

my_model_config = {
# Basic parameters
"config_name": "my_model_config",
Expand Down
12 changes: 8 additions & 4 deletions docs/sphinx_doc/zh_CN/source/tutorial/203-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ API如下:
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | - | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
| | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_dall_e` | - | |
| | Embedding | [`PostAPIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_embedding` | - |
Expand Down Expand Up @@ -540,7 +540,7 @@ com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py">agen
```python
{
"config_name": "my_postapiwrapper_config",
"model_type": "post_api",
"model_type": "post_api_chat",

# 必要参数
"api_url": "https://xxx.xxx",
Expand Down Expand Up @@ -586,6 +586,7 @@ AgentScope允许开发者自定义自己的模型包装器。新的模型包装
- 继承自`ModelWrapperBase`类,
- 提供`model_type`字段以在模型配置中标识这个Model Wrapper类,并
- 实现`__init__``__call__`函数。
- 调用`agentscope.register_model_wrapper_class`函数,将其注册到AgentScope中。

```python
from agentscope.models import ModelWrapperBase
Expand All @@ -604,10 +605,13 @@ class MyModelWrapper(ModelWrapperBase):
# ...
```

在创建新的模型包装器类之后,模型包装器将自动注册到AgentScope中。
您可以直接在模型配置中使用它。
然后调用`register_model_wrapper_class`函数将其注册到AgentScope中。

```python
import agentscope

agentscope.register_model_wrapper_class(MyModelWrapper)

my_model_config = {
# 基础参数
"config_name": "my_model_config",
Expand Down
2 changes: 2 additions & 0 deletions src/agentscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
from ._init import init
from ._init import print_llm_usage
from ._init import state_dict
from ._init import register_model_wrapper_class

__all__ = [
"init",
"state_dict",
"print_llm_usage",
"msghub",
"register_model_wrapper_class",
]
32 changes: 30 additions & 2 deletions src/agentscope/_init.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# -*- coding: utf-8 -*-
"""The init function for the package."""
import json
from typing import Optional, Union, Sequence
from typing import Optional, Union, Sequence, Type

from loguru import logger

from agentscope import agents
from .agents import AgentBase
from .logging import LOG_LEVEL
from .constants import _DEFAULT_SAVE_DIR
from .constants import _DEFAULT_LOG_LEVEL
from .constants import _DEFAULT_CACHE_DIR
from .manager import ASManager
from .manager import ASManager, ModelManager
from .models import ModelWrapperBase

# init the singleton class by default settings to avoid reinit in subprocess
# especially in spawn mode, which will copy the object from the parent process
Expand Down Expand Up @@ -122,3 +126,27 @@ def state_dict() -> dict:
def print_llm_usage() -> dict:
"""Print the usage of LLM."""
return ASManager.get_instance().monitor.print_llm_usage()


def register_model_wrapper_class(
model_wrapper_class: Type[ModelWrapperBase],
exist_ok: bool = False,
) -> None:
"""Register the model wrapper in AgentScope so that you can use it with
model configurations.
Args:
model_wrapper_class (`Type[ModelWrapperBase]`):
The model wrapper class to be registered, which must inherit from
`ModelWrapperBase`.
exist_ok (`bool`, defaults to `False`):
Whether to overwrite the existing model wrapper with the same name.
"""
logger.info(
f"Registering model wrapper class `{model_wrapper_class.__name__}`.",
)
ModelManager.get_instance().register_model_wrapper_class(
model_wrapper_class,
exist_ok,
)
logger.info("Done.")
136 changes: 91 additions & 45 deletions src/agentscope/manager/_model.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
# -*- coding: utf-8 -*-
"""The model manager for AgentScope."""
import importlib
import json
from typing import Any, Union, Sequence
import os
from typing import Any, Union, Type

from loguru import logger

from ..models import ModelWrapperBase, _get_model_wrapper
from ..models import ModelWrapperBase, _BUILD_IN_MODEL_WRAPPERS


class ModelManager:
"""The model manager for AgentScope, which is responsible for loading and
managing model configurations and models."""

_instance = None

model_configs: dict[str, dict] = {}
"""The model configs"""

_instance = None
model_wrapper_mapping: dict[str, Type[ModelWrapperBase]] = {}
"""The registered model wrapper classes."""

def __new__(cls, *args: Any, **kwargs: Any) -> Any:
"""Create a singleton instance."""
Expand Down Expand Up @@ -43,6 +48,13 @@ def __init__(
) -> None:
"""Initialize the model manager with model configs"""
self.model_configs = {}
self.model_wrapper_mapping = {}

for cls_name in _BUILD_IN_MODEL_WRAPPERS:
models_module = importlib.import_module("agentscope.models")
cls = getattr(models_module, cls_name)
if getattr(cls, "model_type", None):
self.register_model_wrapper_class(cls, exist_ok=False)

def initialize(
self,
Expand Down Expand Up @@ -77,39 +89,50 @@ def load_model_configs(
if clear_existing:
self.clear_model_configs()

cfgs = None
cfgs = model_configs

if isinstance(model_configs, str):
with open(model_configs, "r", encoding="utf-8") as f:
# Load model configs from a path
if isinstance(cfgs, str):
if not os.path.exists(cfgs):
raise FileNotFoundError(
f"Cannot find the model configs file in the given path "
f"`{model_configs}`.",
)
with open(cfgs, "r", encoding="utf-8") as f:
cfgs = json.load(f)

if isinstance(model_configs, dict):
cfgs = [model_configs]
# Load model configs from a dict or a list of dicts
if isinstance(cfgs, dict):
cfgs = [cfgs]

if isinstance(model_configs, list):
if not all(isinstance(_, dict) for _ in model_configs):
if isinstance(cfgs, list):
if not all(isinstance(_, dict) for _ in cfgs):
raise ValueError(
"The model config unit should be a dict.",
)
cfgs = model_configs

if cfgs is None:
else:
raise TypeError(
f"Invalid type of model_configs, it could be a dict, a list "
f"of dicts, or a path to a json file (containing a dict or a "
f"list of dicts), but got {type(model_configs)}",
)

formatted_configs = _format_configs(configs=cfgs)
# Check and register the model configs
for cfg in cfgs:
# Check the format of model configs
if "config_name" not in cfg or "model_type" not in cfg:
raise ValueError(
"The `config_name` and `model_type` fields are required "
f"for model config, but got: {cfg}",
)

# check if name is unique
for cfg in formatted_configs:
if cfg["config_name"] in self.model_configs:
config_name = cfg["config_name"]
if config_name in self.model_configs:
logger.warning(
f"config_name [{cfg['config_name']}] already exists.",
f"Config name [{config_name}] already exists.",
)
continue
self.model_configs[cfg["config_name"]] = cfg
self.model_configs[config_name] = cfg

# print the loaded model configs
logger.info(
Expand Down Expand Up @@ -138,10 +161,16 @@ def get_model_by_config_name(self, config_name: str) -> ModelWrapperBase:
)

model_type = config["model_type"]
if model_type not in self.model_wrapper_mapping:
raise ValueError(
f"Unsupported model_type `{model_type}`, currently supported "
f"model types: "
f"{', '.join(list(self.model_wrapper_mapping.keys()))}. ",
)

kwargs = {k: v for k, v in config.items() if k != "model_type"}

return _get_model_wrapper(model_type=model_type)(**kwargs)
return self.model_wrapper_mapping[model_type](**kwargs)

def get_config_by_name(self, config_name: str) -> Union[dict, None]:
"""Load the model config by name, and return the config dict."""
Expand All @@ -159,33 +188,50 @@ def load_dict(self, data: dict) -> None:
assert "model_configs" in data
self.model_configs = data["model_configs"]

def flush(self) -> None:
"""Flush the model manager."""
self.clear_model_configs()

def register_model_wrapper_class(
self,
model_wrapper_class: Type[ModelWrapperBase],
exist_ok: bool,
) -> None:
"""Register the model wrapper class.
def _format_configs(
configs: Union[Sequence[dict], dict],
) -> Sequence:
"""Check the format of model configs.
Args:
model_wrapper_class (`Type[ModelWrapperBase]`):
The model wrapper class to be registered, which must inherit
from `ModelWrapperBase`.
exist_ok (`bool`):
Whether to overwrite the existing model wrapper with the same
name.
"""

Args:
configs (Union[Sequence[dict], dict]): configs in dict format.
if not issubclass(model_wrapper_class, ModelWrapperBase):
raise TypeError(
"The model wrapper class should inherit from "
f"ModelWrapperBase, but got {model_wrapper_class}.",
)

Returns:
Sequence[dict]: converted ModelConfig list.
"""
if isinstance(configs, dict):
configs = [configs]
for config in configs:
if "config_name" not in config:
if not hasattr(model_wrapper_class, "model_type"):
raise ValueError(
"The `config_name` field is required for Cfg",
)
if "model_type" not in config:
logger.warning(
"`model_type` is not provided in config"
f"[{config['config_name']}],"
" use `PostAPIModelWrapperBase` by default.",
f"The model wrapper class `{model_wrapper_class}` should "
f"have a `model_type` attribute.",
)
return configs

model_type = model_wrapper_class.model_type
if model_type in self.model_wrapper_mapping:
if exist_ok:
logger.warning(
f'Model wrapper "{model_type}" '
"already exists, overwrite it.",
)
self.model_wrapper_mapping[model_type] = model_wrapper_class
else:
raise ValueError(
f'Model wrapper "{model_type}" already exists, '
"please set `exist_ok=True` to overwrite it.",
)
else:
self.model_wrapper_mapping[model_type] = model_wrapper_class

def flush(self) -> None:
"""Flush the model manager."""
self.clear_model_configs()
Loading

0 comments on commit 36a3943

Please sign in to comment.