diff --git a/docs/sphinx_doc/en/source/tutorial/203-model.md b/docs/sphinx_doc/en/source/tutorial/203-model.md index 2aad86e1e..1a69b0c49 100644 --- a/docs/sphinx_doc/en/source/tutorial/203-model.md +++ b/docs/sphinx_doc/en/source/tutorial/203-model.md @@ -90,7 +90,7 @@ In the current AgentScope, the supported `model_type` types, the corresponding | | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... | | LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - | | Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... | -| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - | +| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | - | - | | | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... | | | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_dall_e` | - | | | | Embedding | [`PostAPIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_embedding` | - | @@ -519,7 +519,7 @@ com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py">agen ```python { "config_name": "my_postapiwrapper_config", - "model_type": "post_api", + "model_type": "post_api_chat", # Required parameters "api_url": "https://xxx.xxx", @@ -566,6 +566,7 @@ The new model wrapper class should - inherit from `ModelWrapperBase` class, - provide a `model_type` field to identify this model wrapper in the model configuration, and - implement its `__init__` and `__call__` functions. +- register the new model wrapper class by calling `agentscope.register_model_wrapper_class` function The following is an example for creating a new model wrapper class. @@ -586,10 +587,13 @@ class MyModelWrapper(ModelWrapperBase): # ... ``` -After creating the new model wrapper class, the model wrapper will be registered into AgentScope automatically. -You can use it in the model configuration directly. +Then we register the new model wrapper class and use it in the model configuration. ```python +import agentscope + +agentscope.register_model_wrapper_class(MyModelWrapper) + my_model_config = { # Basic parameters "config_name": "my_model_config", diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md b/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md index dda8afe22..43eecc7fe 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md @@ -110,7 +110,7 @@ API如下: | | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... | | LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - | | Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... | -| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - | +| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | - | - | | | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... | | | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_dall_e` | - | | | | Embedding | [`PostAPIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_embedding` | - | @@ -540,7 +540,7 @@ com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py">agen ```python { "config_name": "my_postapiwrapper_config", - "model_type": "post_api", + "model_type": "post_api_chat", # 必要参数 "api_url": "https://xxx.xxx", @@ -586,6 +586,7 @@ AgentScope允许开发者自定义自己的模型包装器。新的模型包装 - 继承自`ModelWrapperBase`类, - 提供`model_type`字段以在模型配置中标识这个Model Wrapper类,并 - 实现`__init__`和`__call__`函数。 +- 调用`agentscope.register_model_wrapper_class`函数,将其注册到AgentScope中。 ```python from agentscope.models import ModelWrapperBase @@ -604,10 +605,13 @@ class MyModelWrapper(ModelWrapperBase): # ... ``` -在创建新的模型包装器类之后,模型包装器将自动注册到AgentScope中。 -您可以直接在模型配置中使用它。 +然后调用`register_model_wrapper_class`函数将其注册到AgentScope中。 ```python +import agentscope + +agentscope.register_model_wrapper_class(MyModelWrapper) + my_model_config = { # 基础参数 "config_name": "my_model_config", diff --git a/src/agentscope/__init__.py b/src/agentscope/__init__.py index 2b48936a9..8a7bb667f 100644 --- a/src/agentscope/__init__.py +++ b/src/agentscope/__init__.py @@ -22,10 +22,12 @@ from ._init import init from ._init import print_llm_usage from ._init import state_dict +from ._init import register_model_wrapper_class __all__ = [ "init", "state_dict", "print_llm_usage", "msghub", + "register_model_wrapper_class", ] diff --git a/src/agentscope/_init.py b/src/agentscope/_init.py index ba33152a7..42631b227 100644 --- a/src/agentscope/_init.py +++ b/src/agentscope/_init.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- """The init function for the package.""" import json -from typing import Optional, Union, Sequence +from typing import Optional, Union, Sequence, Type + +from loguru import logger + from agentscope import agents from .agents import AgentBase from .logging import LOG_LEVEL from .constants import _DEFAULT_SAVE_DIR from .constants import _DEFAULT_LOG_LEVEL from .constants import _DEFAULT_CACHE_DIR -from .manager import ASManager +from .manager import ASManager, ModelManager +from .models import ModelWrapperBase # init the singleton class by default settings to avoid reinit in subprocess # especially in spawn mode, which will copy the object from the parent process @@ -122,3 +126,27 @@ def state_dict() -> dict: def print_llm_usage() -> dict: """Print the usage of LLM.""" return ASManager.get_instance().monitor.print_llm_usage() + + +def register_model_wrapper_class( + model_wrapper_class: Type[ModelWrapperBase], + exist_ok: bool = False, +) -> None: + """Register the model wrapper in AgentScope so that you can use it with + model configurations. + + Args: + model_wrapper_class (`Type[ModelWrapperBase]`): + The model wrapper class to be registered, which must inherit from + `ModelWrapperBase`. + exist_ok (`bool`, defaults to `False`): + Whether to overwrite the existing model wrapper with the same name. + """ + logger.info( + f"Registering model wrapper class `{model_wrapper_class.__name__}`.", + ) + ModelManager.get_instance().register_model_wrapper_class( + model_wrapper_class, + exist_ok, + ) + logger.info("Done.") diff --git a/src/agentscope/manager/_model.py b/src/agentscope/manager/_model.py index 0f63f14be..ab63d4960 100644 --- a/src/agentscope/manager/_model.py +++ b/src/agentscope/manager/_model.py @@ -1,21 +1,26 @@ # -*- coding: utf-8 -*- """The model manager for AgentScope.""" +import importlib import json -from typing import Any, Union, Sequence +import os +from typing import Any, Union, Type from loguru import logger -from ..models import ModelWrapperBase, _get_model_wrapper +from ..models import ModelWrapperBase, _BUILD_IN_MODEL_WRAPPERS class ModelManager: """The model manager for AgentScope, which is responsible for loading and managing model configurations and models.""" + _instance = None + model_configs: dict[str, dict] = {} """The model configs""" - _instance = None + model_wrapper_mapping: dict[str, Type[ModelWrapperBase]] = {} + """The registered model wrapper classes.""" def __new__(cls, *args: Any, **kwargs: Any) -> Any: """Create a singleton instance.""" @@ -43,6 +48,13 @@ def __init__( ) -> None: """Initialize the model manager with model configs""" self.model_configs = {} + self.model_wrapper_mapping = {} + + for cls_name in _BUILD_IN_MODEL_WRAPPERS: + models_module = importlib.import_module("agentscope.models") + cls = getattr(models_module, cls_name) + if getattr(cls, "model_type", None): + self.register_model_wrapper_class(cls, exist_ok=False) def initialize( self, @@ -77,39 +89,50 @@ def load_model_configs( if clear_existing: self.clear_model_configs() - cfgs = None + cfgs = model_configs - if isinstance(model_configs, str): - with open(model_configs, "r", encoding="utf-8") as f: + # Load model configs from a path + if isinstance(cfgs, str): + if not os.path.exists(cfgs): + raise FileNotFoundError( + f"Cannot find the model configs file in the given path " + f"`{model_configs}`.", + ) + with open(cfgs, "r", encoding="utf-8") as f: cfgs = json.load(f) - if isinstance(model_configs, dict): - cfgs = [model_configs] + # Load model configs from a dict or a list of dicts + if isinstance(cfgs, dict): + cfgs = [cfgs] - if isinstance(model_configs, list): - if not all(isinstance(_, dict) for _ in model_configs): + if isinstance(cfgs, list): + if not all(isinstance(_, dict) for _ in cfgs): raise ValueError( "The model config unit should be a dict.", ) - cfgs = model_configs - - if cfgs is None: + else: raise TypeError( f"Invalid type of model_configs, it could be a dict, a list " f"of dicts, or a path to a json file (containing a dict or a " f"list of dicts), but got {type(model_configs)}", ) - formatted_configs = _format_configs(configs=cfgs) + # Check and register the model configs + for cfg in cfgs: + # Check the format of model configs + if "config_name" not in cfg or "model_type" not in cfg: + raise ValueError( + "The `config_name` and `model_type` fields are required " + f"for model config, but got: {cfg}", + ) - # check if name is unique - for cfg in formatted_configs: - if cfg["config_name"] in self.model_configs: + config_name = cfg["config_name"] + if config_name in self.model_configs: logger.warning( - f"config_name [{cfg['config_name']}] already exists.", + f"Config name [{config_name}] already exists.", ) continue - self.model_configs[cfg["config_name"]] = cfg + self.model_configs[config_name] = cfg # print the loaded model configs logger.info( @@ -138,10 +161,16 @@ def get_model_by_config_name(self, config_name: str) -> ModelWrapperBase: ) model_type = config["model_type"] + if model_type not in self.model_wrapper_mapping: + raise ValueError( + f"Unsupported model_type `{model_type}`, currently supported " + f"model types: " + f"{', '.join(list(self.model_wrapper_mapping.keys()))}. ", + ) kwargs = {k: v for k, v in config.items() if k != "model_type"} - return _get_model_wrapper(model_type=model_type)(**kwargs) + return self.model_wrapper_mapping[model_type](**kwargs) def get_config_by_name(self, config_name: str) -> Union[dict, None]: """Load the model config by name, and return the config dict.""" @@ -159,33 +188,50 @@ def load_dict(self, data: dict) -> None: assert "model_configs" in data self.model_configs = data["model_configs"] - def flush(self) -> None: - """Flush the model manager.""" - self.clear_model_configs() - + def register_model_wrapper_class( + self, + model_wrapper_class: Type[ModelWrapperBase], + exist_ok: bool, + ) -> None: + """Register the model wrapper class. -def _format_configs( - configs: Union[Sequence[dict], dict], -) -> Sequence: - """Check the format of model configs. + Args: + model_wrapper_class (`Type[ModelWrapperBase]`): + The model wrapper class to be registered, which must inherit + from `ModelWrapperBase`. + exist_ok (`bool`): + Whether to overwrite the existing model wrapper with the same + name. + """ - Args: - configs (Union[Sequence[dict], dict]): configs in dict format. + if not issubclass(model_wrapper_class, ModelWrapperBase): + raise TypeError( + "The model wrapper class should inherit from " + f"ModelWrapperBase, but got {model_wrapper_class}.", + ) - Returns: - Sequence[dict]: converted ModelConfig list. - """ - if isinstance(configs, dict): - configs = [configs] - for config in configs: - if "config_name" not in config: + if not hasattr(model_wrapper_class, "model_type"): raise ValueError( - "The `config_name` field is required for Cfg", - ) - if "model_type" not in config: - logger.warning( - "`model_type` is not provided in config" - f"[{config['config_name']}]," - " use `PostAPIModelWrapperBase` by default.", + f"The model wrapper class `{model_wrapper_class}` should " + f"have a `model_type` attribute.", ) - return configs + + model_type = model_wrapper_class.model_type + if model_type in self.model_wrapper_mapping: + if exist_ok: + logger.warning( + f'Model wrapper "{model_type}" ' + "already exists, overwrite it.", + ) + self.model_wrapper_mapping[model_type] = model_wrapper_class + else: + raise ValueError( + f'Model wrapper "{model_type}" already exists, ' + "please set `exist_ok=True` to overwrite it.", + ) + else: + self.model_wrapper_mapping[model_type] = model_wrapper_class + + def flush(self) -> None: + """Flush the model manager.""" + self.clear_model_configs() diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py index 0a6894b35..60de46f44 100644 --- a/src/agentscope/models/__init__.py +++ b/src/agentscope/models/__init__.py @@ -1,8 +1,5 @@ # -*- coding: utf-8 -*- """ Import modules in models package.""" -from typing import Type - -from loguru import logger from .model import ModelWrapperBase from .response import ModelResponse @@ -42,6 +39,26 @@ YiChatWrapper, ) +_BUILD_IN_MODEL_WRAPPERS = [ + "PostAPIChatWrapper", + "OpenAIChatWrapper", + "OpenAIDALLEWrapper", + "OpenAIEmbeddingWrapper", + "DashScopeChatWrapper", + "DashScopeImageSynthesisWrapper", + "DashScopeTextEmbeddingWrapper", + "DashScopeMultiModalWrapper", + "OllamaChatWrapper", + "OllamaEmbeddingWrapper", + "OllamaGenerationWrapper", + "GeminiChatWrapper", + "GeminiEmbeddingWrapper", + "ZhipuAIChatWrapper", + "ZhipuAIEmbeddingWrapper", + "LiteLLMChatWrapper", + "YiChatWrapper", +] + __all__ = [ "ModelWrapperBase", "ModelResponse", @@ -65,22 +82,3 @@ "LiteLLMChatWrapper", "YiChatWrapper", ] - - -def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]: - """Get the specific type of model wrapper - - Args: - model_type (`str`): The model type name. - - Returns: - `Type[ModelWrapperBase]`: The corresponding model wrapper class. - """ - wrapper = ModelWrapperBase.get_wrapper(model_type=model_type) - if wrapper is None: - logger.warning( - f"Unsupported model_type [{model_type}]," - "use PostApiModelWrapper instead.", - ) - return PostAPIModelWrapperBase - return wrapper diff --git a/src/agentscope/models/dashscope_model.py b/src/agentscope/models/dashscope_model.py index 6192a5d31..490883671 100644 --- a/src/agentscope/models/dashscope_model.py +++ b/src/agentscope/models/dashscope_model.py @@ -119,8 +119,6 @@ class DashScopeChatWrapper(DashScopeWrapperBase): model_type: str = "dashscope_chat" - deprecated_model_type: str = "tongyi_chat" - def __init__( self, config_name: str, diff --git a/src/agentscope/models/model.py b/src/agentscope/models/model.py index 5ace8c161..0f2cca403 100644 --- a/src/agentscope/models/model.py +++ b/src/agentscope/models/model.py @@ -1,64 +1,11 @@ # -*- coding: utf-8 -*- -"""The configuration file should contain one or a list of model configs, -and each model config should follow the following format. +"""The model wrapper base class.""" -.. code-block:: python - - { - "config_name": "{config_name}", - "model_type": "openai_chat" | "post_api" | ..., - ... - } - -After that, you can specify model by {config_name}. - -Note: - The parameters for different types of models are different. For OpenAI API, - the format is: - - .. code-block:: python - - { - "config_name": "{id of your model}", - "model_type": "openai_chat", - "model_name": "{model_name_for_openai, e.g. gpt-3.5-turbo}", - "api_key": "{your_api_key}", - "organization": "{your_organization, if needed}", - "client_args": { - # ... - }, - "generate_args": { - # ... - } - } - - - For Post API, toking huggingface inference API as an example, its format - is: - - .. code-block:: python - - { - "config_name": "{config_name}", - "model_type": "post_api", - "api_url": "{api_url}", - "headers": {"Authorization": "Bearer {API_TOKEN}"}, - "max_length": {max_length_of_model}, - "timeout": {timeout}, - "max_retries": {max_retries}, - "generate_args": { - "temperature": 0.5, - # ... - } - } - -""" from __future__ import annotations import inspect import time -from abc import ABCMeta from functools import wraps -from typing import Sequence, Any, Callable, Union, List, Type +from typing import Sequence, Any, Callable, Union, List from loguru import logger @@ -140,32 +87,7 @@ def checking_wrapper(self: Any, *args: Any, **kwargs: Any) -> dict: return checking_wrapper -class _ModelWrapperMeta(ABCMeta): - """A meta call to replace the model wrapper's __call__ function with - wrapper about error handling.""" - - def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any: - if "__call__" in attrs: - attrs["__call__"] = _response_parse_decorator(attrs["__call__"]) - return super().__new__(mcs, name, bases, attrs) - - def __init__(cls, name: Any, bases: Any, attrs: Any) -> None: - if not hasattr(cls, "_registry"): - cls._registry = {} - cls._type_registry = {} - cls._deprecated_type_registry = {} - else: - cls._registry[name] = cls - if hasattr(cls, "model_type"): - cls._type_registry[cls.model_type] = cls - if hasattr(cls, "deprecated_model_type"): - cls._deprecated_type_registry[ - cls.deprecated_model_type - ] = cls - super().__init__(name, bases, attrs) - - -class ModelWrapperBase(metaclass=_ModelWrapperMeta): +class ModelWrapperBase: """The base class for model wrapper.""" model_type: str @@ -202,23 +124,6 @@ def __init__( self.model_name = model_name logger.info(f"Initialize model by configuration [{config_name}]") - @classmethod - def get_wrapper(cls, model_type: str) -> Type[ModelWrapperBase]: - """Get the specific model wrapper""" - if model_type in cls._type_registry: - return cls._type_registry[model_type] # type: ignore[return-value] - elif model_type in cls._registry: - return cls._registry[model_type] # type: ignore[return-value] - elif model_type in cls._deprecated_type_registry: - deprecated_cls = cls._deprecated_type_registry[model_type] - logger.warning( - f"Model type [{model_type}] will be deprecated in future " - f"releases, please use [{deprecated_cls.model_type}] instead.", - ) - return deprecated_cls # type: ignore[return-value] - else: - return None # type: ignore[return-value] - def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: """Processing input with the model.""" raise NotImplementedError( diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py index 7d7ccc081..ec6032c62 100644 --- a/src/agentscope/models/openai_model.py +++ b/src/agentscope/models/openai_model.py @@ -135,8 +135,6 @@ class OpenAIChatWrapper(OpenAIWrapperBase): model_type: str = "openai_chat" - deprecated_model_type: str = "openai" - substrings_in_vision_models_names = ["gpt-4-turbo", "vision", "gpt-4o"] """The substrings in the model names of vision models.""" diff --git a/src/agentscope/models/post_model.py b/src/agentscope/models/post_model.py index 626673c97..48797627c 100644 --- a/src/agentscope/models/post_model.py +++ b/src/agentscope/models/post_model.py @@ -19,7 +19,7 @@ class PostAPIModelWrapperBase(ModelWrapperBase, ABC): """The base model wrapper for the model deployed on the POST API.""" - model_type: str = "post_api" + model_type: str def __init__( self, @@ -236,8 +236,6 @@ class PostAPIDALLEWrapper(PostAPIModelWrapperBase): model_type: str = "post_api_dall_e" - deprecated_model_type: str = "post_api_dalle" - def _parse_response(self, response: dict) -> ModelResponse: if "data" not in response["data"]["response"]: if "error" in response["data"]["response"]: diff --git a/tests/model_test.py b/tests/model_test.py index 992aa528e..2ec1d8e14 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -10,15 +10,31 @@ from agentscope.models import ( ModelResponse, ModelWrapperBase, + YiChatWrapper, + LiteLLMChatWrapper, + ZhipuAIEmbeddingWrapper, + ZhipuAIChatWrapper, + GeminiEmbeddingWrapper, + GeminiChatWrapper, + OllamaGenerationWrapper, + OllamaEmbeddingWrapper, + OllamaChatWrapper, + DashScopeMultiModalWrapper, + DashScopeTextEmbeddingWrapper, + DashScopeChatWrapper, + DashScopeImageSynthesisWrapper, + OpenAIEmbeddingWrapper, + OpenAIDALLEWrapper, OpenAIChatWrapper, - PostAPIModelWrapperBase, - _get_model_wrapper, + PostAPIChatWrapper, ) class TestModelWrapperSimple(ModelWrapperBase): """A simple model wrapper class for test usage""" + model_type: str = "TestModelWrapperSimple" + def __call__(self, *args: Any, **kwargs: Any) -> ModelResponse: return ModelResponse(text=self.config_name) @@ -36,22 +52,30 @@ def setUp(self) -> None: """Init for BasicModelTest""" agentscope.init(disable_saving=True) - def test_model_registry(self) -> None: - """Test the automatic registration mechanism of model wrapper.""" + def test_build_in_model_wrapper_classes(self) -> None: + """Test the build in model wrapper classes.""" # get model wrapper class by class name - self.assertEqual( - _get_model_wrapper(model_type="TestModelWrapperSimple"), - TestModelWrapperSimple, - ) - # get model wrapper class by model type - self.assertEqual( - _get_model_wrapper(model_type="openai_chat"), - OpenAIChatWrapper, - ) - # return PostAPIModelWrapperBase if model_type is not supported - self.assertEqual( - _get_model_wrapper(model_type="unknown_model_wrapper"), - PostAPIModelWrapperBase, + self.assertDictEqual( + ModelManager.get_instance().model_wrapper_mapping, + { + "post_api_chat": PostAPIChatWrapper, + "openai_chat": OpenAIChatWrapper, + "openai_dall_e": OpenAIDALLEWrapper, + "openai_embedding": OpenAIEmbeddingWrapper, + "dashscope_chat": DashScopeChatWrapper, + "dashscope_image_synthesis": DashScopeImageSynthesisWrapper, + "dashscope_text_embedding": DashScopeTextEmbeddingWrapper, + "dashscope_multimodal": DashScopeMultiModalWrapper, + "ollama_chat": OllamaChatWrapper, + "ollama_embedding": OllamaEmbeddingWrapper, + "ollama_generate": OllamaGenerationWrapper, + "gemini_chat": GeminiChatWrapper, + "gemini_embedding": GeminiEmbeddingWrapper, + "zhipuai_chat": ZhipuAIChatWrapper, + "zhipuai_embedding": ZhipuAIEmbeddingWrapper, + "litellm_chat": LiteLLMChatWrapper, + "yi_chat": YiChatWrapper, + }, ) @patch("loguru.logger.warning") @@ -67,7 +91,7 @@ def test_load_model_configs(self, mock_logging: MagicMock) -> None: "generate_args": {"temperature": 0.5}, }, { - "model_type": "post_api", + "model_type": "post_api_chat", "config_name": "my_post_api", "api_url": "https://xxx", "headers": {}, @@ -110,9 +134,12 @@ def test_load_model_configs(self, mock_logging: MagicMock) -> None: clear_existing=False, ) mock_logging.assert_called_once_with( - "config_name [gpt-4] already exists.", + "Config name [gpt-4] already exists.", ) + def test_register_model_wrapper_class(self) -> None: + """Test the model wrapper class registration.""" + model_manager = ModelManager.get_instance() model_manager.load_model_configs( model_configs={ "model_type": "TestModelWrapperSimple", @@ -121,9 +148,21 @@ def test_load_model_configs(self, mock_logging: MagicMock) -> None: "args": {}, }, ) + + # Not registered model wrapper class + self.assertRaises( + ValueError, + model_manager.get_model_by_config_name, + "test_model_wrapper", + ) + + # Register model wrapper class + agentscope.register_model_wrapper_class(TestModelWrapperSimple) + test_model = model_manager.get_model_by_config_name( "test_model_wrapper", ) + response = test_model() self.assertEqual(response.text, "test_model_wrapper") model_manager.clear_model_configs()