From 197d59748e82555011c1622d40cbc98cec8cefb0 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Thu, 14 Mar 2024 20:48:29 +0800 Subject: [PATCH] Add ollama based local model services in AgentScope. (#60) * Add ollama api * bug fix * include ollama in __init__.py (agentscope.models package) * bug fix * Add unit test for ollama * bug fix and reformat * add package ollama * reformat * add logging directory management * reformat * re-correct ollama test with single instance * re-correct ollama test with single instance * reformat --- setup.py | 4 +- src/agentscope/models/__init__.py | 13 +- src/agentscope/models/ollama_model.py | 377 ++++++++++++++++++++++++++ tests/ollama_test.py | 168 ++++++++++++ 4 files changed, 560 insertions(+), 2 deletions(-) create mode 100644 src/agentscope/models/ollama_model.py create mode 100644 tests/ollama_test.py diff --git a/setup.py b/setup.py index 2a9e6a99f..29f1fe073 100644 --- a/setup.py +++ b/setup.py @@ -44,12 +44,14 @@ "tiktoken", "Pillow", "requests", - "openai>=1.3.0", "numpy", "Flask==3.0.0", "Flask-Cors==4.0.0", "Flask-SocketIO==5.3.6", + # TODO: move into other requires "dashscope==1.14.1", + "openai>=1.3.0", + "ollama>=0.1.7", ] distribute_requires = minimal_requires + rpc_requires diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py index 2d6dc5fae..3d3da7abc 100644 --- a/src/agentscope/models/__init__.py +++ b/src/agentscope/models/__init__.py @@ -22,6 +22,11 @@ DashScopeImageSynthesisWrapper, DashScopeTextEmbeddingWrapper, ) +from .ollama_model import ( + OllamaChatWrapper, + OllamaEmbeddingWrapper, + OllamaGenerationWrapper, +) __all__ = [ @@ -39,6 +44,9 @@ "DashScopeChatWrapper", "DashScopeImageSynthesisWrapper", "DashScopeTextEmbeddingWrapper", + "OllamaChatWrapper", + "OllamaEmbeddingWrapper", + "OllamaGenerationWrapper", ] _MODEL_CONFIGS: dict[str, dict] = {} @@ -97,7 +105,10 @@ def load_model_by_config_name(config_name: str) -> ModelWrapperBase: ) model_type = config.model_type - return _get_model_wrapper(model_type=model_type)(**config) + + kwargs = {k: v for k, v in config.items() if k != "model_type"} + + return _get_model_wrapper(model_type=model_type)(**kwargs) def clear_model_configs() -> None: diff --git a/src/agentscope/models/ollama_model.py b/src/agentscope/models/ollama_model.py new file mode 100644 index 000000000..5de732840 --- /dev/null +++ b/src/agentscope/models/ollama_model.py @@ -0,0 +1,377 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for Ollama models.""" +from typing import Sequence, Any, Optional + +from loguru import logger + +from agentscope.models import ModelWrapperBase, ModelResponse +from agentscope.utils import QuotaExceededError, MonitorFactory +from agentscope.utils.monitor import get_full_name + +try: + import ollama +except ImportError: + ollama = None + + +class OllamaWrapperBase(ModelWrapperBase): + """The base class for Ollama model wrappers. + + To use Ollama API, please + 1. First install ollama server from https://ollama.com/download and + start the server + 2. Pull the model by `ollama pull {model_name}` in terminal + After that, you can use the ollama API. + """ + + model: str + """The model name used in ollama API.""" + + options: dict + """A dict contains the options for ollama generation API, + e.g. 
{"temperature": 0, "seed": 123}""" + + keep_alive: str + """Controls how long the model will stay loaded into memory following + the request.""" + + def __init__( + self, + config_name: str, + model: str, + options: dict = None, + keep_alive: str = "5m", + ) -> None: + """Initialize the model wrapper for Ollama API. + + Args: + model (`str`): + The model name used in ollama API. + options (`dict`, default `None`): + The extra keyword arguments used in Ollama api generation, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `5m`): + Controls how long the model will stay loaded into memory + following the request. + """ + + super().__init__(config_name=config_name) + + self.model = model + self.options = options + self.keep_alive = keep_alive + + self.monitor = None + + self._register_default_metrics() + + def _register_default_metrics(self) -> None: + """Register metrics to the monitor.""" + raise NotImplementedError( + "The _register_default_metrics function is not Implemented.", + ) + + # TODO: move into ModelWrapperBase + def _metric(self, metric_name: str) -> str: + """Add the class name and model name as prefix to the metric name. + + Args: + metric_name (`str`): + The metric name. + + Returns: + `str`: Metric name of this wrapper. + """ + return get_full_name(name=metric_name, prefix=self.model) + + +class OllamaChatWrapper(OllamaWrapperBase): + """The model wrapper for Ollama chat API.""" + + model_type: str = "ollama_chat" + + def __call__( + self, + messages: Sequence[dict], + options: Optional[dict] = None, + keep_alive: Optional[str] = None, + **kwargs: Any, + ) -> ModelResponse: + """Generate response from the given messages. + + Args: + messages (`Sequence[dict]`): + A list of messages, each message is a dict contains the `role` + and `content` of the message. + options (`dict`, default `None`): + The extra arguments used in ollama chat API, which takes + effect only on this call, and will be merged with the + `options` input in the constructor, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `None`): + How long the model will stay loaded into memory following + the request, which takes effect only on this call, and will + override the `keep_alive` input in the constructor. + + Returns: + `ModelResponse`: + The response text in `text` field, and the raw response in + `raw` field. 
+ """ + # step1: prepare parameters accordingly + if options is None: + options = self.options + else: + options = {**self.options, **options} + + keep_alive = keep_alive or self.keep_alive + + # step2: forward to generate response + response = ollama.chat( + model=self.model, + messages=messages, + options=options, + keep_alive=keep_alive, + **kwargs, + ) + + # step2: record the api invocation if needed + self._save_model_invocation( + arguments={ + "model": self.model, + "messages": messages, + "options": options, + "keep_alive": keep_alive, + **kwargs, + }, + json_response=response, + ) + + # step3: monitor the response + try: + prompt_tokens = response["prompt_eval_count"] + completion_tokens = response["eval_count"] + self.monitor.update( + { + "call_counter": 1, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + ) + except (QuotaExceededError, KeyError) as e: + logger.error(e.message) + + # step4: return response + return ModelResponse( + text=response["message"]["content"], + raw=response, + ) + + def _register_default_metrics(self) -> None: + """Register metrics to the monitor.""" + self.monitor = MonitorFactory.get_monitor() + self.monitor.register( + self._metric("call_counter"), + metric_unit="times", + ) + self.monitor.register( + self._metric("prompt_tokens"), + metric_unit="tokens", + ) + self.monitor.register( + self._metric("completion_tokens"), + metric_unit="token", + ) + self.monitor.register( + self._metric("total_tokens"), + metric_unit="token", + ) + + +class OllamaEmbeddingWrapper(OllamaWrapperBase): + """The model wrapper for Ollama embedding API.""" + + model_type: str = "ollama_embedding" + + def __call__( + self, + prompt: str, + options: Optional[dict] = None, + keep_alive: Optional[str] = None, + **kwargs: Any, + ) -> ModelResponse: + """Generate embedding from the given prompt. + + Args: + prompt (`str`): + The prompt to generate response. + options (`dict`, default `None`): + The extra arguments used in ollama embedding API, which takes + effect only on this call, and will be merged with the + `options` input in the constructor, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `None`): + How long the model will stay loaded into memory following + the request, which takes effect only on this call, and will + override the `keep_alive` input in the constructor. + + Returns: + `ModelResponse`: + The response embedding in `embedding` field, and the raw + response in `raw` field. 
+ """ + # step1: prepare parameters accordingly + if options is None: + options = self.options + else: + options = {**self.options, **options} + + keep_alive = keep_alive or self.keep_alive + + # step2: forward to generate response + response = ollama.embeddings( + model=self.model, + prompt=prompt, + options=options, + keep_alive=keep_alive, + **kwargs, + ) + + # step3: record the api invocation if needed + self._save_model_invocation( + arguments={ + "model": self.model, + "prompt": prompt, + "options": options, + "keep_alive": keep_alive, + **kwargs, + }, + json_response=response, + ) + + # step4: monitor the response + try: + self.monitor.update( + {"call_counter": 1}, + prefix=self.model, + ) + except (QuotaExceededError, KeyError) as e: + logger.error(e.message) + + # step5: return response + return ModelResponse( + embedding=response["embedding"], + raw=response, + ) + + def _register_default_metrics(self) -> None: + """Register metrics to the monitor.""" + self.monitor = MonitorFactory.get_monitor() + self.monitor.register( + self._metric("call_counter"), + metric_unit="times", + ) + + +class OllamaGenerationWrapper(OllamaWrapperBase): + """The model wrapper for Ollama generation API.""" + + model_type: str = "ollama_generate" + + def __call__( + self, + prompt: str, + options: Optional[dict] = None, + keep_alive: Optional[str] = None, + **kwargs: Any, + ) -> ModelResponse: + """Generate response from the given prompt. + + Args: + prompt (`str`): + The prompt to generate response. + options (`dict`, default `None`): + The extra arguments used in ollama generation API, which takes + effect only on this call, and will be merged with the + `options` input in the constructor, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `None`): + How long the model will stay loaded into memory following + the request, which takes effect only on this call, and will + override the `keep_alive` input in the constructor. + + Returns: + `ModelResponse`: + The response text in `text` field, and the raw response in + `raw` field. 
+ + """ + # step1: prepare parameters accordingly + if options is None: + options = self.options + else: + options = {**self.options, **options} + + keep_alive = keep_alive or self.keep_alive + + # step2: forward to generate response + response = ollama.generate( + model=self.model, + prompt=prompt, + options=options, + keep_alive=keep_alive, + ) + + # step3: record the api invocation if needed + self._save_model_invocation( + arguments={ + "model": self.model, + "prompt": prompt, + "options": options, + "keep_alive": keep_alive, + **kwargs, + }, + json_response=response, + ) + + # step4: monitor the response + try: + prompt_tokens = response["prompt_eval_count"] + completion_tokens = response["eval_count"] + self.monitor.update( + { + "call_counter": 1, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + ) + except (QuotaExceededError, KeyError) as e: + logger.error(e.message) + + # step5: return response + return ModelResponse( + text=response["response"], + raw=response, + ) + + def _register_default_metrics(self) -> None: + """Register metrics to the monitor.""" + self.monitor = MonitorFactory.get_monitor() + self.monitor.register( + self._metric("call_counter"), + metric_unit="times", + ) + self.monitor.register( + self._metric("prompt_tokens"), + metric_unit="tokens", + ) + self.monitor.register( + self._metric("completion_tokens"), + metric_unit="token", + ) + self.monitor.register( + self._metric("total_tokens"), + metric_unit="token", + ) diff --git a/tests/ollama_test.py b/tests/ollama_test.py new file mode 100644 index 000000000..197834300 --- /dev/null +++ b/tests/ollama_test.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +"""Unit test for Ollama model APIs.""" +import os +import unittest +import uuid +from unittest.mock import patch, MagicMock + +import agentscope +from agentscope.models import load_model_by_config_name +from agentscope.utils import MonitorFactory + + +class OllamaModelWrapperTest(unittest.TestCase): + """Unit test for Ollama model APIs.""" + + def setUp(self) -> None: + """Init for OllamaModelWrapperTest.""" + self.dummy_response = { + "model": "llama2", + "created_at": "2024-03-12T04:16:48.911377Z", + "message": { + "role": "assistant", + "content": ( + "Hello! It's nice to meet you. 
Is there something I can " + "help you with or would you like to chat?", + ), + }, + "done": True, + "total_duration": 20892900042, + "load_duration": 20019679292, + "prompt_eval_count": 22, + "prompt_eval_duration": 149094000, + "eval_count": 26, + "eval_duration": 721982000, + } + + self.dummy_embedding = { + "embedding": [1.0, 2.0, 3.0], + } + + self.dummy_generate = { + "model": "llama2", + "created_at": "2024-03-12T03:42:19.621919Z", + "response": "\n1 + 1 = 2", + "done": True, + "context": [ + 518, + 25580, + 29962, + 3532, + 14816, + 29903, + 29958, + 5299, + 829, + 14816, + 29903, + 6778, + 13, + 13, + 29896, + 29974, + 29896, + 29922, + 518, + 29914, + 25580, + 29962, + 13, + 13, + 29896, + 718, + 29871, + 29896, + 353, + 29871, + 29906, + ], + "total_duration": 6146120041, + "load_duration": 6677375, + "prompt_eval_count": 9, + "prompt_eval_duration": 5913554000, + "eval_count": 9, + "eval_duration": 223689000, + } + self.tmp = MonitorFactory._instance # pylint: disable=W0212 + MonitorFactory._instance = None # pylint: disable=W0212 + self.db_path = f"test-{uuid.uuid4()}.db" + _ = MonitorFactory.get_monitor(db_path=self.db_path) + + @patch("ollama.chat") + def test_ollama_chat(self, mock_chat: MagicMock) -> None: + """Unit test for ollama chat API.""" + # prepare the mock + mock_chat.return_value = self.dummy_response + + # run test + agentscope.init( + model_configs={ + "config_name": "my_ollama_chat", + "model_type": "ollama_chat", + "model": "llama2", + "options": { + "temperature": 0.5, + }, + "keep_alive": "5m", + }, + ) + + model = load_model_by_config_name("my_ollama_chat") + response = model(messages=[{"role": "user", "content": "Hi!"}]) + + self.assertEqual(response.raw, self.dummy_response) + + @patch("ollama.embeddings") + def test_ollama_embedding(self, mock_embeddings: MagicMock) -> None: + """Unit test for ollama embeddings API.""" + # prepare the mock + mock_embeddings.return_value = self.dummy_embedding + + # run test + agentscope.init( + model_configs={ + "config_name": "my_ollama_embedding", + "model_type": "ollama_embedding", + "model": "llama2", + "options": { + "temperature": 0.5, + }, + "keep_alive": "5m", + }, + ) + + model = load_model_by_config_name("my_ollama_embedding") + response = model(prompt="Hi!") + + self.assertEqual(response.raw, self.dummy_embedding) + + @patch("ollama.generate") + def test_ollama_generate(self, mock_generate: MagicMock) -> None: + """Unit test for ollama generate API.""" + # prepare the mock + mock_generate.return_value = self.dummy_generate + + # run test + agentscope.init( + model_configs={ + "config_name": "my_ollama_generate", + "model_type": "ollama_generate", + "model": "llama2", + "options": None, + "keep_alive": "5m", + }, + ) + + model = load_model_by_config_name("my_ollama_generate") + response = model(prompt="1+1=") + + self.assertEqual(response.raw, self.dummy_generate) + + def tearDown(self) -> None: + """Clean up after each test.""" + MonitorFactory._instance = self.tmp # pylint: disable=W0212 + os.remove(self.db_path) + + +if __name__ == "__main__": + unittest.main()
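
For reference, a minimal usage sketch of the new wrappers (illustrative only; it mirrors the configuration used in tests/ollama_test.py and assumes a local Ollama server is running with `llama2` already pulled via `ollama pull llama2`):

    # Usage sketch, not part of the diff above. Assumes a reachable local
    # Ollama server and that `llama2` has been pulled; adjust the model
    # name to whatever is available locally.
    import agentscope
    from agentscope.models import load_model_by_config_name

    agentscope.init(
        model_configs={
            "config_name": "my_ollama_chat",
            "model_type": "ollama_chat",
            "model": "llama2",
            "options": {"temperature": 0.5},
            "keep_alive": "5m",
        },
    )

    model = load_model_by_config_name("my_ollama_chat")
    response = model(messages=[{"role": "user", "content": "Hi!"}])

    print(response.text)  # generated reply text
    print(response.raw)   # full response dict returned by ollama.chat

The embedding and generation wrappers are used the same way: register a config with `model_type` set to `"ollama_embedding"` or `"ollama_generate"`, call the loaded model with `prompt="..."` instead of `messages`, and read the result from the `embedding` or `text` field of the returned `ModelResponse`.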