From 45586d830ef4597b6969805fce9ef9093740425e Mon Sep 17 00:00:00 2001 From: wyp311395 Date: Mon, 6 Jan 2025 15:04:06 +0800 Subject: [PATCH] add muagen sdk v0.1.0 modify config modify config modify config --- .github/workflows/docker-image-pull.yml | 4 +- .github/workflows/docker-image.yml | 4 +- docker-compose.yaml | 2 +- docker_pull_images.sh | 8 +- examples/ekg_examples/start.py | 103 ++- examples/muagent_examples/docchat_example.py | 3 +- examples/test_config.py.example | 106 ++- muagent/__init__.py | 18 +- muagent/agents/__init__.py | 30 + muagent/agents/agent_util.py | 202 ++++++ muagent/agents/base_agent.py | 504 +++++++++++++ muagent/agents/functioncall_agent.py | 237 +++++++ muagent/agents/group_agent.py | 227 ++++++ muagent/agents/react_agent.py | 284 ++++++++ muagent/agents/single_agent.py | 205 ++++++ muagent/agents/task_agent.py | 291 ++++++++ muagent/agents/user_agent.py | 140 ++++ muagent/agents/util.py | 0 muagent/base_configs/env_config.py | 16 +- .../prompts/functioncall_template_prompt.py | 36 + .../prompts/intention_template_prompt.py | 18 +- .../memory/hierarchical_memory_manager.py | 3 +- muagent/connector/memory_manager.py | 3 +- muagent/connector/schema/general_schema.py | 6 +- muagent/connector/utils.py | 18 +- muagent/db_handler/__init__.py | 22 +- muagent/{orm => db_handler}/db.py | 0 .../graph_db_handler/nebula_handler.py | 25 +- .../vector_db_handler/local_faiss_handler.py | 21 +- muagent/ekg_project.py | 659 +++++++++++++++++ muagent/httpapis/ekg_construct/api.py | 30 +- muagent/llm_models/llm_shemas.py | 50 ++ muagent/llm_models/openai_model.py | 130 +++- muagent/memory/__init__.py | 5 - muagent/memory_manager/__init__.py | 13 + muagent/memory_manager/base_memory_manager.py | 271 +++++++ .../hierarchical_memory_manager.py | 4 +- .../memory_manager/local_memory_manager.py | 443 ++++++++++++ .../memory_manager/tbase_memory_manager.py | 628 +++++++++++++++++ muagent/models/__init__.py | 104 +++ muagent/models/base_model.py | 504 +++++++++++++ muagent/models/dashscope_model.py | 514 ++++++++++++++ muagent/models/kimi_model.py | 372 ++++++++++ muagent/models/ollama_model.py | 490 +++++++++++++ muagent/models/openai_model.py | 667 ++++++++++++++++++ muagent/models/qwen_model.py | 461 ++++++++++++ muagent/models/yi_model.py | 291 ++++++++ muagent/orm/__init__.py | 23 - muagent/project_manager.py | 70 ++ muagent/prompt_manager/__init__.py | 8 + muagent/prompt_manager/base.py | 32 + muagent/prompt_manager/base_prompt_manager.py | 506 +++++++++++++ .../prompt_manager/common_prompt_manager.py | 320 +++++++++ muagent/prompt_manager/language/en.py | 89 +++ muagent/prompt_manager/language/zh.py | 87 +++ muagent/prompt_manager/util.py | 94 +++ muagent/sandbox/__init__.py | 3 +- muagent/sandbox/nbclient.py | 297 ++++++++ muagent/schemas/__init__.py | 12 + muagent/schemas/agent_config.py | 68 ++ muagent/schemas/apis/ekg_api_schema.py | 21 +- muagent/schemas/common/__init__.py | 8 +- muagent/schemas/common/actions.py | 76 ++ muagent/schemas/common/log.py | 38 + muagent/schemas/kb/base_schema.py | 2 +- muagent/schemas/memory.py | 193 +++++ muagent/schemas/message.py | 258 +++++++ muagent/schemas/models/__init__.py | 11 + muagent/schemas/models/llm_shemas.py | 50 ++ muagent/schemas/models/model.py | 56 ++ muagent/schemas/project_config.py | 114 +++ .../ekg_construct/ekg_construct_base.py | 40 +- .../ekg_inference/intention_match_rule.py | 7 +- .../service/ekg_inference/intention_router.py | 13 +- .../service/ui_file_service/code_base_cds.py | 2 +- 
.../ui_file_service/document_base_cds.py | 2 +- .../ui_file_service/document_file_cds.py | 2 +- muagent/service/utils.py | 18 +- muagent/tools/__init__.py | 7 +- muagent/tools/base_tool.py | 105 ++- muagent/tools/metrics_query.py | 3 +- muagent/tools/undercover.py | 315 +++++++++ muagent/tools/werewolf.py | 314 +++++++++ muagent/utils/common_utils.py | 45 ++ requirements.txt | 2 +- setup.py | 6 +- tests/agents/funccall_agent_test.py | 97 +++ tests/agents/group_agent_test.py | 71 ++ tests/agents/react_agent_test.py | 62 ++ tests/agents/single_agent_test.py | 119 ++++ tests/agents/task_agent_test.py | 66 ++ tests/llm_models/embedding_test.py | 48 ++ tests/llm_models/model_test.py | 93 +++ .../local_memory_manager_test.py | 161 +++++ tests/memory_manager/local_mm_crud_test.py | 117 +++ .../tbase_memory_manager_test.py | 153 ++++ tests/orm/table_test.py | 3 +- tests/prompt_manager/base_test.py | 116 +++ tests/prompt_manager/extend_common_pm_test.py | 161 +++++ tests/prompt_manager/new_pm_test.py | 233 ++++++ tests/retrieval/faiss_test.py | 69 ++ tests/sandbox/nbclient_test.py | 218 ++++++ tests/service/ekg_project_test.py | 423 +++++++++++ tests/test_config.py.example | 104 +++ tests/tools/get_tool.py | 30 + 105 files changed, 13656 insertions(+), 177 deletions(-) create mode 100644 muagent/agents/__init__.py create mode 100644 muagent/agents/agent_util.py create mode 100644 muagent/agents/base_agent.py create mode 100644 muagent/agents/functioncall_agent.py create mode 100644 muagent/agents/group_agent.py create mode 100644 muagent/agents/react_agent.py create mode 100644 muagent/agents/single_agent.py create mode 100644 muagent/agents/task_agent.py create mode 100644 muagent/agents/user_agent.py create mode 100644 muagent/agents/util.py create mode 100644 muagent/base_configs/prompts/functioncall_template_prompt.py rename muagent/{orm => db_handler}/db.py (100%) create mode 100644 muagent/ekg_project.py create mode 100644 muagent/llm_models/llm_shemas.py delete mode 100644 muagent/memory/__init__.py create mode 100644 muagent/memory_manager/__init__.py create mode 100644 muagent/memory_manager/base_memory_manager.py rename muagent/{memory => memory_manager}/hierarchical_memory_manager.py (96%) create mode 100644 muagent/memory_manager/local_memory_manager.py create mode 100644 muagent/memory_manager/tbase_memory_manager.py create mode 100644 muagent/models/__init__.py create mode 100644 muagent/models/base_model.py create mode 100644 muagent/models/dashscope_model.py create mode 100644 muagent/models/kimi_model.py create mode 100644 muagent/models/ollama_model.py create mode 100644 muagent/models/openai_model.py create mode 100644 muagent/models/qwen_model.py create mode 100644 muagent/models/yi_model.py delete mode 100644 muagent/orm/__init__.py create mode 100644 muagent/project_manager.py create mode 100644 muagent/prompt_manager/__init__.py create mode 100644 muagent/prompt_manager/base.py create mode 100644 muagent/prompt_manager/base_prompt_manager.py create mode 100644 muagent/prompt_manager/common_prompt_manager.py create mode 100644 muagent/prompt_manager/language/en.py create mode 100644 muagent/prompt_manager/language/zh.py create mode 100644 muagent/prompt_manager/util.py create mode 100644 muagent/sandbox/nbclient.py create mode 100644 muagent/schemas/agent_config.py create mode 100644 muagent/schemas/common/actions.py create mode 100644 muagent/schemas/common/log.py create mode 100644 muagent/schemas/memory.py create mode 100644 muagent/schemas/message.py create mode 100644 
muagent/schemas/models/__init__.py create mode 100644 muagent/schemas/models/llm_shemas.py create mode 100644 muagent/schemas/models/model.py create mode 100644 muagent/schemas/project_config.py create mode 100644 muagent/tools/undercover.py create mode 100644 muagent/tools/werewolf.py create mode 100644 tests/agents/funccall_agent_test.py create mode 100644 tests/agents/group_agent_test.py create mode 100644 tests/agents/react_agent_test.py create mode 100644 tests/agents/single_agent_test.py create mode 100644 tests/agents/task_agent_test.py create mode 100644 tests/llm_models/embedding_test.py create mode 100644 tests/llm_models/model_test.py create mode 100644 tests/memory_manager/local_memory_manager_test.py create mode 100644 tests/memory_manager/local_mm_crud_test.py create mode 100644 tests/memory_manager/tbase_memory_manager_test.py create mode 100644 tests/prompt_manager/base_test.py create mode 100644 tests/prompt_manager/extend_common_pm_test.py create mode 100644 tests/prompt_manager/new_pm_test.py create mode 100644 tests/retrieval/faiss_test.py create mode 100644 tests/sandbox/nbclient_test.py create mode 100644 tests/service/ekg_project_test.py create mode 100644 tests/tools/get_tool.py diff --git a/.github/workflows/docker-image-pull.yml b/.github/workflows/docker-image-pull.yml index ca64365..2c3d34a 100644 --- a/.github/workflows/docker-image-pull.yml +++ b/.github/workflows/docker-image-pull.yml @@ -11,8 +11,8 @@ jobs: architecture: [amd64, arm64] os: [linux] service: - - name: runtime:0.1.0 - - name: muagent:0.1.0 + - name: runtime:0.1.1 + - name: muagent:0.1.1 - name: ekgfrontend:0.1.0 steps: diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 04e0c50..3dcd8ba 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -12,7 +12,7 @@ jobs: - name: runtime context: ./runtime dockerfile: ./runtime/Dockerfile.no-package - tag: ghcr.io/codefuse-ai/runtime:0.1.0 + tag: ghcr.io/codefuse-ai/runtime:0.1.1 tag_latest: ghcr.io/codefuse-ai/runtime:latest - name: ekgfrontend context: . @@ -22,7 +22,7 @@ jobs: - name: ekgservice context: . dockerfile: ./Dockerfile_gh - tag: ghcr.io/codefuse-ai/muagent:0.1.0 + tag: ghcr.io/codefuse-ai/muagent:0.1.1 tag_latest: ghcr.io/codefuse-ai/muagent:latest steps: diff --git a/docker-compose.yaml b/docker-compose.yaml index 84cbafd..ed26951 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -190,7 +190,7 @@ services: context: . 
dockerfile: Dockerfile container_name: ekgservice - image: muagent:0.1.0 + image: muagent:0.1.1 environment: USER: root TZ: "${TZ}" diff --git a/docker_pull_images.sh b/docker_pull_images.sh index 82998c1..85bd44a 100644 --- a/docker_pull_images.sh +++ b/docker_pull_images.sh @@ -17,11 +17,11 @@ docker pull redis/redis-stack:7.4.0-v0 docker pull ollama/ollama:0.3.6 # pull images from github ghcr.io by nju -docker pull ghcr.nju.edu.cn/runtime:0.1.0 -docker pull ghcr.nju.edu.cn/muagent:0.1.0 +docker pull ghcr.nju.edu.cn/runtime:0.1.1 +docker pull ghcr.nju.edu.cn/muagent:0.1.1 docker pull ghcr.nju.edu.cn/ekgfrontend:0.1.0 # # pull images from github ghcr.io -# docker pull ghcr.io/runtime:0.1.0 -# docker pull ghcr.io/muagent:0.1.0 +# docker pull ghcr.io/runtime:0.1.1 +# docker pull ghcr.io/muagent:0.1.1 # docker pull ghcr.io/ekgfrontend:0.1.0 diff --git a/examples/ekg_examples/start.py b/examples/ekg_examples/start.py index 838ca87..0e1ed81 100644 --- a/examples/ekg_examples/start.py +++ b/examples/ekg_examples/start.py @@ -37,6 +37,7 @@ import test_config from muagent.schemas.db import * +from muagent.schemas.apis.ekg_api_schema import LLMFCRequest from muagent.db_handler import * from muagent.llm_models.llm_config import EmbedConfig, LLMConfig from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService @@ -46,7 +47,8 @@ from pydantic import BaseModel - +from muagent.schemas.models import ModelConfig +from muagent.models import get_model cur_dir = os.path.dirname(__file__) @@ -92,56 +94,75 @@ def update_params(self, **kwargs): def _llm_type(self, *args): return "" - - def predict(self, prompt: str, stop = None) -> str: - return self._call(prompt, stop) - - def _call(self, prompt: str, - stop = None) -> str: + + def _get_model(self): """_call """ - return_str = "" - stop = stop or self.stop - - if self.model_type == "ollama": - stream = ollama.chat( - model=self.model_name, - messages=[{'role': 'user', 'content': prompt}], - stream=True, - ) - answer = "" - for chunk in stream: - answer += chunk['message']['content'] - - return answer - elif self.model_type == "openai": + if self.model_type in [ + "ollama", "qwen", "openai", "lingyiwanwu", + "kimi", "moonshot", + ]: from muagent.llm_models.openai_model import getChatModelFromConfig llm_config = LLMConfig( model_name=self.model_name, - model_engine="openai", + model_engine=self.model_type, api_key=self.api_key, api_base_url=self.url, temperature=self.temperature, stop=self.stop ) model = getChatModelFromConfig(llm_config) - return model.predict(prompt, stop=self.stop) - elif self.model_type in ["lingyiwanwu", "kimi", "moonshot", "qwen"]: - from muagent.llm_models.openai_model import getChatModelFromConfig - llm_config = LLMConfig( + else: + model_config = ModelConfig( + model_type=self.model_type, model_name=self.model_name, - model_engine=self.model_type, api_key=self.api_key, - api_base_url=self.url, + api_url=self.url, temperature=self.temperature, - stop=self.stop ) - model = getChatModelFromConfig(llm_config) - return model.predict(prompt, stop=self.stop) - else: - pass + model = get_model(model_config) + return model + + def predict(self, prompt: str, stop = None) -> str: + return self._call(prompt, stop) - return return_str + def fc(self, request: LLMFCRequest) -> str: + """_function_call + """ + if self.model_type not in [ + "openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen" + ]: + return f"{self.model_type} not in valid model range" + + model = self._get_model() + return model.fc( + 
messages=request.messages, + tools=request.tools, + tool_choice=request.tool_choice, + parallel_tool_calls=request.parallel_tool_calls, + ) + + def _call(self, prompt: str, + stop = None) -> str: + """Call the underlying model with the given prompt and return its completion. + """ + stop = stop or self.stop + if ( + self.model_type not in [ + "openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen" + ] + and self.model_type not in [ + "dashscope_chat", "moonshot_chat", "ollama_chat", + "openai_chat", "qwen_chat", "yi_chat", + "dashscope_text_embedding", "ollama_embedding", "openai_embedding", "qwen_text_embedding" + ] + ): + # neither a legacy engine name nor a new-style model type + return f"{self.model_type} not in valid model range" + + model = self._get_model() + return model.predict(prompt, stop=self.stop) class CustomEmbeddings(Embeddings): @@ -185,6 +206,17 @@ def _get_sentence_emb(self, sentence: str) -> dict: ) text2vector_dict = get_embedding("openai", [sentence], embed_config=embed_config) return text2vector_dict[sentence] + elif self.embedding_type in [ + "dashscope_text_embedding", "ollama_embedding", "openai_embedding", "qwen_text_embedding" + ]: + model_config = ModelConfig( + model_type=self.embedding_type, + model_name=self.model_name, + api_key=self.api_key, + api_url=self.url, + ) + model = get_model(model_config) + return model.embed_query(sentence) else: pass @@ -280,6 +312,7 @@ def embed_query(self, text: str) -> List[float]: llm_config=llm_config, tb_config=tb_config, gb_config=gb_config, + initialize_space=True, clear_history_data=clear_history_data ) diff --git a/examples/muagent_examples/docchat_example.py b/examples/muagent_examples/docchat_example.py index 1ba7ed4..89174cd 100644 --- a/examples/muagent_examples/docchat_example.py +++ b/examples/muagent_examples/docchat_example.py @@ -60,7 +60,8 @@ # create your knowledge base from muagent.service.kb_api import create_kb, upload_files2kb from muagent.utils.server_utils import run_async -from muagent.orm import create_tables +# from muagent.orm import create_tables +from muagent.db_handler import create_tables # use to test, don't create some directory diff --git a/examples/test_config.py.example b/examples/test_config.py.example index efc3603..ac03016 100644 --- a/examples/test_config.py.example +++ b/examples/test_config.py.example @@ -1,6 +1,8 @@ import os, openai, base64 from loguru import logger +import json + +os.environ["DM_llm_name"] = 'Qwen2_72B_Instruct_OpsGPT' #or gpt_4 + # Fallback LLM configuration OPENAI_API_BASE = "https://api.openai.com/v1" os.environ["API_BASE_URL"] = OPENAI_API_BASE @@ -19,6 +21,78 @@ os.environ["gpt4-llm_temperature"] = "0.0" + +MODEL_CONFIGS = { + # old llm config + "default": { + "model_name": "gpt-3.5-turbo", + "model_engine": "qwen", + "temperature": "0", + "api_key": "", + "api_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + }, + "codefuser":{ + "model_name": "gpt-4", + "model_engine": "openai", + "temperature": "0", + "api_key": "", + "api_base_url": OPENAI_API_BASE, + }, + # new llm config + "dashscope_chat": { + "model_type": "dashscope_chat", + "model_name": "qwen2.5-72b-instruct", + "api_key": "", + }, + "moonshot_chat": { + "model_type": "moonshot_chat", + "model_name": "moonshot-v1-8k", + "api_key": "", + }, + "ollama_chat": { + "model_type": "ollama_chat", + "model_name": "qwen2.5-0.5b", + "api_key": "", + }, + "openai_chat": { + "model_type": "openai_chat", + "model_name": "gpt-4", + "api_key": "", + }, + "qwen_chat": { + "model_type": "qwen_chat", + "model_name": "qwen2.5-72b-instruct", + "api_key": "", + }, + "yi_chat": { + "model_type": "yi_chat", + "model_name": "yi-lightning", + "api_key": "", + }, + # embedding configs + "dashscope_text_embedding": { + "model_type": "dashscope_text_embedding", + "model_name": "text-embedding-v3", + "api_key": "", + }, + "ollama_embedding": { + "model_type": "ollama_embedding", + "model_name": "qwen2.5-0.5b", + "api_key": "", + }, + "openai_embedding": { + "model_type": "openai_embedding", + "model_name": "text-embedding-ada-002", + "api_key": "", + }, + "qwen_text_embedding": { + "model_type": "dashscope_text_embedding", + "model_name": "text-embedding-v3", + "api_key": "", + }, +} + +os.environ["MODEL_CONFIGS"] = json.dumps(MODEL_CONFIGS) + #### NebulaHandler #### os.environ['nb_host'] = 'graphd' os.environ['nb_port'] = '9669' @@ -41,8 +115,36 @@ os.environ["tb_index_name"] = "ekg_migration_new" os.environ['tb_definition_value'] = 'message_test_new' os.environ['tb_expire_time'] = '604800' #86400*7 -# clear history data in tb and gb -os.environ['clear_history_data'] = 'True' + +################# +## DB_CONFIGS ## +################# +DB_CONFIGS = { + "gb_config": { + "gb_type": "NebulaHandler", + "extra_kwargs": { + 'host':'graphd', + 'port': '9669', + 'username': os.environ['nb_username'], + 'password': os.environ['nb_password'], + 'space': "client" + } + }, + "tb_config": { + "tb_type": 'TBaseHandler', + "index_name": "opsgptkg", + "host": 'redis-stack', + "port": '6379', + "username": os.environ['tb_username'], + "password": os.environ['tb_password'], + "extra_kwargs": { + "definition_value": "opsgptkg", + "memory_definition_value": "opsgptkg_message" + } + } +} +os.environ["DB_CONFIGS"] = json.dumps(DB_CONFIGS) + ########################################
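For reference, this is how the serialized `MODEL_CONFIGS` round-trips on the muagent side — a minimal sketch, assuming the `ModelConfig` fields used in `start.py` above (`model_type`, `model_name`, `api_key`, `api_url`) and the `qwen_chat` entry from this example config:

```python
import json
import os

from muagent.models import get_model
from muagent.schemas.models import ModelConfig

# Deserialize the env var written by test_config.py.example above.
configs = json.loads(os.environ["MODEL_CONFIGS"])
chat_config = configs["qwen_chat"]  # any of the "new llm config" entries works

model = get_model(ModelConfig(
    model_type=chat_config["model_type"],
    model_name=chat_config["model_name"],
    api_key=chat_config["api_key"],
))
print(model.predict("hello"))  # predict() is the call the agents use below
```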
"yi-lightning" , + "api_key": "", + }, + # embedding configs + "dashscope_text_embedding": { + "model_type": "dashscope_text_embedding", + "model_name": "text-embedding-v3", + "api_key": "", + }, + "ollama_embedding": { + "model_type": "ollama_embedding", + "model_name": "qwen2.5-0.5b", + "api_key": "", + }, + "openai_embedding": { + "model_type": "openai_embedding", + "model_name": "text-embedding-ada-002", + "api_key": "", + }, + "qwen_text_embedding": { + "model_type": "dashscope_text_embedding", + "model_name": "text-embedding-v3", + "api_key": "", + }, +} + +os.environ["MODEL_CONFIGS"] = json.dumps(MODEL_CONFIGS) + #### NebulaHandler #### os.environ['nb_host'] = 'graphd' os.environ['nb_port'] = '9669' @@ -41,8 +115,36 @@ os.environ["tb_index_name"] = "ekg_migration_new" os.environ['tb_definition_value'] = 'message_test_new' os.environ['tb_expire_time'] = '604800' #86400*7 -# clear history data in tb and gb -os.environ['clear_history_data'] = 'True' + +################# +## DB_CONFIGS ## +################# +DB_CONFIGS = { + "gb_config": { + "gb_type": "NebulaHandler", + "extra_kwargs": { + 'host':'graphd', + 'port': '9669', + 'username': os.environ['nb_username'], + 'password': os.environ['nb_password'], + 'space': "client" + } + }, + "tb_config": { + "tb_type": 'TBaseHandler', + "index_name": "opsgptkg", + "host": 'redis-stack', + "port": '6379', + "username": os.environ['tb_username'], + "password": os.environ['tb_password'], + "extra_kwargs": { + "definition_value": "opsgptkg", + "memory_definition_value": "opsgptkg_message" + } + } +} +os.environ["DB_CONFIGS"] = json.dumps(DB_CONFIGS) + ######################################## diff --git a/muagent/__init__.py b/muagent/__init__.py index 67c87c2..a079b65 100644 --- a/muagent/__init__.py +++ b/muagent/__init__.py @@ -1,7 +1,11 @@ -# encoding: utf-8 -''' -@author: 温进 -@file: __init__.py.py -@time: 2023/11/9 下午4:01 -@desc: -''' \ No newline at end of file +from .ekg_project import EKG, get_ekg_project_config_from_env +from .project_manager import get_project_config_from_env +from .models import get_model +from .agents import get_agent +from .tools import get_tool + +__all__ = [ + "EKG", "get_model", "get_agent", "get_tool", + "get_ekg_project_config_from_env", + "get_project_config_from_env" +] \ No newline at end of file diff --git a/muagent/agents/__init__.py b/muagent/agents/__init__.py new file mode 100644 index 0000000..271b61f --- /dev/null +++ b/muagent/agents/__init__.py @@ -0,0 +1,30 @@ +from .base_agent import BaseAgent +from .single_agent import SingleAgent +from .react_agent import ReactAgent +from .task_agent import TaskAgent +from .group_agent import GroupAgent +from .user_agent import UserAgent +from .functioncall_agent import FunctioncallAgent +from ..schemas import AgentConfig + +__all__ = [ + "BaseAgent", + "SingleAgent", + "ReactAgent", + "TaskAgent", + "GroupAgent", + "UserAgent", + "FunctioncallAgent" +] + + +def get_agent(agent_config: AgentConfig) -> BaseAgent: + """Get the agent by agent config + + Args: + agent_config (`AgentConfig`): The agent config + + Returns: + `BaseAgent`: The specific agent + """ + return BaseAgent.init_from_project_config(agent_config) \ No newline at end of file diff --git a/muagent/agents/agent_util.py b/muagent/agents/agent_util.py new file mode 100644 index 0000000..a5561ae --- /dev/null +++ b/muagent/agents/agent_util.py @@ -0,0 +1,202 @@ +import re, uuid, os +from typing import ( + Union, + Tuple, + List +) +from loguru import logger + +from ..schemas import Memory, Message +from 
..schemas.common import ActionStatus, LogVerboseEnum +from ..tools import get_tool +from ..sandbox import NBClientBox + +from muagent.base_configs.env_config import KB_ROOT_PATH + +class MessageUtil: + """Utility class for processing messages and executing code or tools based on message content.""" + + def __init__( + self, + workdir_path: str = KB_ROOT_PATH, + log_verbose: str = "0", + **kwargs + ) -> None: + """Initialize the MessageUtil with the specified working directory and log verbosity. + + Args: + workdir_path (str): Path to the working directory where files may be saved. + log_verbose (str): Verbosity level for logging. + **kwargs: Additional keyword arguments for future extensions. + """ + self.codebox = NBClientBox(do_code_exe=True) # Initialize code execution box + + self.workdir_path = workdir_path # Set the working directory path + self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose # Configure logging verbosity + + def step_router( + self, + msg: Message, + session_index: str = "", + **kwargs + ) -> Tuple[Message, ...]: + """Route a message to the appropriate step for processing based on its action status. + + Args: + msg (Message): The input message that needs processing. + session_index (str): The session identifier for managing the conversation. + **kwargs: Additional parameters for processing. + + Returns: + Tuple[Message, ...]: The processed message and any observation message. + """ + session_index = msg.session_index or session_index or str(uuid.uuid4()) + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"message.action_status: {msg.action_status}") + + observation_msg = None + + # Determine the action to take based on the message's action status + if msg.action_status == ActionStatus.CODE_EXECUTING: + msg, observation_msg = self.code_step(msg, session_index) + elif msg.action_status == ActionStatus.TOOL_USING: + msg, observation_msg = self.tool_step(msg, session_index, **kwargs) + elif msg.action_status == ActionStatus.CODING2FILE: + self.save_code2file(msg, self.workdir_path) + # Handle other action statuses as needed (currently no operations for these) + elif msg.action_status == ActionStatus.CODE_RETRIEVAL: + pass + elif msg.action_status == ActionStatus.CODING: + pass + + return msg, observation_msg + + def code_step(self, msg: Message, session_index: str) -> Message: + """Execute code contained in the message. + + Args: + msg (Message): The message containing code to be executed. + session_index (str): The session identifier for managing the conversation. + + Returns: + Tuple[Message, Message]: The processed message and an observation message regarding code execution. 
+ """ + # Execute the code using the codebox and capture the result + code_answer = self.codebox.chat( + '```python\n{}```'.format(msg.spec_parsed_content.get("code_content", "")) + ) + + # Prepare a response message based on code execution result + code_prompt = ( + f"The return error after executing the above code is {code_answer.code_exe_response},need to recover.\n" + if code_answer.code_exe_type == "error" else + f"The return information after executing the above code is {code_answer.code_exe_response}.\n" + ) + + # Create an observation message for logging code execution outcome + observation_msg = Message( + session_index=session_index, + role_name="function", + role_type="observation", + content="", + step_content="", + input_text=msg.spec_parsed_content.get("code_content", ""), + ) + + uid = str(uuid.uuid1()) # Generate a unique identifier for related content + if code_answer.code_exe_type == "image/png": + # If the code execution produces an image, log the result and update the message + msg.global_kwargs[uid] = code_answer.code_exe_response + msg.step_content += f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n" + msg.parsed_contents.append({"Observation": f"The return figure name is {uid} after executing the above code.\n"}) + observation_msg.update_content(f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n") + observation_msg.update_parsed_content({"Observation": f"The return figure name is {uid} after executing the above code.\n"}) + else: + # Log the standard execution result + msg.step_content += f"\n**Observation:**: {code_prompt}\n" + observation_msg.update_content(code_prompt) + observation_msg.update_parsed_content({"Observation": f"{code_prompt}\n"}) + + # Log the observations at the defined verbosity level + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"**Observation:** {msg.action_status}, {observation_msg.content}") + + return msg, observation_msg + + def tool_step( + self, + msg: Message, + session_index: str, + **kwargs + ) -> Message: + """Execute a tool based on parameters in the message. + + Args: + msg (Message): The message that specifies the tool to be executed. + session_index (str): The session identifier for managing the conversation. + **kwargs: Additional parameters for processing, including available tools. + + Returns: + Tuple[Message, ...]: + The processed message and an observation message regarding the tool execution. 
+ """ + no_tool_msg = "\n**Observation:** there is no tool can execute.\n" # Message for missing tool + tool_names = kwargs.get("tools") # Retrieve available tool names + extra_params = kwargs.get("extra_params", {}) + tool_param = msg.spec_parsed_content.get("tool_param", {}) # Parameters for the tool execution + tool_param.update(extra_params) + tool_name = msg.spec_parsed_content.get("tool_name", "") # Name of the tool to execute + + # Create a message to log the tool execution result + observation_msg = Message( + session_index=session_index, + role_name="function", + role_type="observation", + input_text=str(tool_param), + ) + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"message: {msg.action_status}, {tool_param}") + + if tool_name not in tool_names: + msg.step_content += f"\n{no_tool_msg}" + observation_msg.update_content(no_tool_msg) + observation_msg.update_parsed_content({"Observation": no_tool_msg}) + else: + # Execute the specified tool and capture the result + tool = get_tool(tool_name) + tool_res = tool.run(**tool_param) + msg.step_content += f"\n**Observation:** {tool_res}.\n" + msg.parsed_contents.append({"Observation": f"{tool_res}.\n"}) + observation_msg.update_content(f"**Observation:** {tool_res}.\n") + observation_msg.update_parsed_content({"Observation": f"{tool_res}.\n"}) + + # Log the observations at the defined verbosity level + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"**Observation:** {msg.action_status}, {observation_msg.content}") + + return msg, observation_msg + + def save_code2file(self, msg: Message, project_dir="./"): + """Save the code from the message to a specified file. + + Args: + msg (Message): The message containing the code to be saved. + project_dir (str): Directory path where the code file will be saved. 
+ """ + filename = msg.parsed_content.get("SaveFileName") # Retrieve filename from message content + code = msg.spec_parsed_content.get("code") # Extract code content from the message + + # Replace HTML entities in the code + for k, v in {">": ">", "≥": ">=", "<": "<", "≤": "<="}.items(): + code = code.replace(k, v) + + project_dir_path = os.path.join(self.workdir_path, project_dir) # Construct project directory path + file_path = os.path.join(project_dir_path, filename) # Full path for the output code file + + # Create directories if they don't exist + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + # Write the code to the file + with open(file_path, "w") as f: + f.write(code) diff --git a/muagent/agents/base_agent.py b/muagent/agents/base_agent.py new file mode 100644 index 0000000..47de07c --- /dev/null +++ b/muagent/agents/base_agent.py @@ -0,0 +1,504 @@ +from abc import ABCMeta +from pydantic import BaseModel +import os +from typing import ( + List, + Union, + Generator, + Any, + Type, + Optional, + Literal +) +import copy +from loguru import logger + +from ..schemas import ( + Message, + Memory, + PromptConfig, + AgentConfig, + ProjectConfig +) +from ..schemas.models import ModelConfig +from ..schemas.models import LLMConfig as TempLLMConfig +from ..memory_manager import BaseMemoryManager +from ..prompt_manager import BasePromptManager +from ..models import ModelWrapperBase, get_model + +from .agent_util import MessageUtil +from muagent.connector.schema import LogVerboseEnum +from muagent.llm_models import getChatModelFromConfig + + +class _AgentWapperBase(ABCMeta): + """A meta class to replace the tool wrapper's run function with + a wrapper that handles errors gracefully. + """ + + def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any: + if "__call__" in attrs: + attrs["__call__"] = attrs["__call__"] + return super().__new__(mcs, name, bases, attrs) + + def __init__(cls, name: Any, bases: Any, attrs: Any) -> None: + # Initialize class-level registries for storing agent classes + if not hasattr(cls, "_registry"): + cls._registry = {} # Registry of agent class names + cls._type_registry = {} # Registry of agent class type names + else: + # Register the current class in the registry + cls._registry[name] = cls + cls._type_registry[cls.agent_type] = cls + super().__init__(name, bases, attrs) + + +class BaseAgent(metaclass=_AgentWapperBase): + """Base class for agents, providing initialization and interaction methods. + + You can define your custom agent for your agent work, such as + .. 
diff --git a/muagent/agents/base_agent.py b/muagent/agents/base_agent.py new file mode 100644 index 0000000..47de07c --- /dev/null +++ b/muagent/agents/base_agent.py @@ -0,0 +1,504 @@ +from abc import ABCMeta +from pydantic import BaseModel +import os +from typing import ( + List, + Union, + Generator, + Any, + Type, + Optional, + Literal +) +import copy +from loguru import logger + +from ..schemas import ( + Message, + Memory, + PromptConfig, + AgentConfig, + ProjectConfig +) +from ..schemas.models import ModelConfig +from ..schemas.models import LLMConfig as TempLLMConfig +from ..memory_manager import BaseMemoryManager +from ..prompt_manager import BasePromptManager +from ..models import ModelWrapperBase, get_model + +from .agent_util import MessageUtil +from muagent.connector.schema import LogVerboseEnum +from muagent.llm_models import getChatModelFromConfig + + +class _AgentWrapperBase(ABCMeta): + """A metaclass that registers every agent subclass by class name and by + its `agent_type`, so classes can be resolved via `BaseAgent.get_wrapper`. + """ + + def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any: + if "__call__" in attrs: + attrs["__call__"] = attrs["__call__"] # kept as-is; reserved for future call wrapping + return super().__new__(mcs, name, bases, attrs) + + def __init__(cls, name: Any, bases: Any, attrs: Any) -> None: + # Initialize class-level registries for storing agent classes + if not hasattr(cls, "_registry"): + cls._registry = {} # Registry of agent class names + cls._type_registry = {} # Registry of agent class type names + else: + # Register the current class in the registry + cls._registry[name] = cls + cls._type_registry[cls.agent_type] = cls + super().__init__(name, bases, attrs) + + +class BaseAgent(metaclass=_AgentWrapperBase): + """Base class for agents, providing initialization and interaction methods. + + You can define a custom agent for your own workflow, for example: + .. code-block:: python + + from muagent.agents import BaseAgent + + class SingleAgent(BaseAgent): + '''A custom single-step agent.''' + agent_type: str = "SingleAgent" + '''the agent type''' + agent_id: str + '''the agent id''' + def __init__( + self, + agent_name: str = "codefuse_simpler", + system_prompt: str = "", + input_template: Union[str, BaseModel] = "", + output_template: Union[str, BaseModel] = "", + prompt: Optional[str] = None, + agents: List[str] = [], + tools: List[str] = [], + agent_desc: str = "", + *, + agent_config: Optional[AgentConfig] = None, + model_config: Optional[ModelConfig] = None, + prompt_config: Optional[PromptConfig] = PromptConfig(), + project_config: Optional[ProjectConfig] = None, + # + log_verbose: str = "0", + ): + + super().__init__( + agent_name=agent_name, + system_prompt=system_prompt, + input_template=input_template, + output_template=output_template, + prompt=prompt, + agents=agents, + tools=tools, + agent_desc=agent_desc, + agent_config=agent_config, + model_config=model_config, + prompt_config=prompt_config, + project_config=project_config, + log_verbose=log_verbose + ) + + def step_stream( + self, + query: Message, + memory_manager: Optional[BaseMemoryManager]=None, + session_index: str = "default" + ) -> Generator[Message, None, None]: + '''agent response from multi-message''' + session_index = query.session_index or session_index + # insert query into memory + ... + # transform query into output_message.input_text + ... + # get memory from self or memory_manager + ... + # generate prompt by prompt manager + ... + # predict + ... + # update information + ... + # parse the llm content into the message + ... + # todo: action step + ... + # end + ... + # update self_memory and memory pool + ... + def pre_print( + self, + query: Message, + memory_manager: BaseMemoryManager=None, + tools: List[str] = [], + session_index: str = "default" + + ) -> None: + pass + """ + + agent_type: str = "BaseAgent" + """Defines the type of the agent (default is BaseAgent).""" + + agent_id: str + """Unique identifier for the agent.""" + + def __init__( + self, + agent_name: str = "codefuse_baser", + system_prompt: str = "you are a helpful assistant!\n", + input_template: Union[str, BaseModel] = "", + output_template: Union[str, BaseModel] = "", + prompt: Optional[str] = None, + agents: List[str] = [], + tools: List[str] = [], + agent_desc: str = "", + *, + agent_config: Optional[AgentConfig] = None, + model_config: Optional[ModelConfig] = None, + prompt_config: Optional[PromptConfig] = PromptConfig(), + project_config: Optional[ProjectConfig] = None, + # + log_verbose: str = "0", + ): + # Configure logging verbosity + self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose) + + # Initialize agent properties + self.agent_name = agent_name + self.system_prompt = system_prompt + self.input_template = input_template + self.output_template = output_template + self.prompt = prompt + self.agent_desc = agent_desc + self.agents = agents + self.tools = tools + self.agent_config = agent_config + self.prompt_config = prompt_config + self.model_config = model_config + self.project_config = project_config + # + self.memory: Memory = Memory() + self.message_util = MessageUtil() + + # Initialize agent from configuration data + self._init_from_configs() + + def _init_from_configs(self): + '''Initialize agent's configuration from provided parameters.''' + if not self.agent_name: + raise ValueError( + "An agent must be initialized with a non-empty agent name."
+ ) + # Load configurations + self._init_agent_config() + self._init_model_config() + self._init_prompt_config() + + def _init_agent_config(self): + '''Initialize agent configuration (AgentConfig).''' + # Load agent configuration based on the agent name and project config + if self.agent_name and self.project_config and self.project_config.agent_configs: + tmp_agent_config = self.project_config.agent_configs.get(self.agent_name) + self.agent_config = self.agent_config or tmp_agent_config + + if self.agent_config and isinstance(self.agent_config, AgentConfig): + # Set agent properties from the configuration + self.agent_name = self.agent_config.agent_name + self.system_prompt = self.system_prompt or self.agent_config.system_prompt + self.input_template = self.input_template or self.agent_config.input_template + self.output_template = self.output_template or self.agent_config.output_template + self.prompt = self.prompt or self.agent_config.prompt + self.agent_desc = self.agent_desc or self.agent_config.agent_desc + self.tools = self.tools or self.agent_config.tools or self.tools + self.agents = self.agents or self.agent_config.agents + self._llm_config_name = self.agent_config.llm_config_name + self._em_config_name = self.agent_config.em_config_name + self._prompt_config_name = self.agent_config.prompt_config_name + + def _init_model_config(self): + '''Initialize model configuration (ModelConfig).''' + # Check if model_config was provided + if self.model_config: + pass + # Load model configuration from project config if not provided + elif self.agent_name and self.project_config and self.project_config.model_configs: + if self._llm_config_name in self.project_config.model_configs: + self.model_config = self.project_config.model_configs[self._llm_config_name] + elif "default_chat" in self.project_config.model_configs: + self.model_config = self.project_config.model_configs["default_chat"] + else: + raise ValueError( + f"While init a model, project_config must have model configs. " + f"However, there is something wrong in agent_name: {self.agent_name} " + f"agent_config: {self.project_config.model_configs} " + ) + else: + raise ValueError( + f"While init a model, it must have model config. 
" + f"However, there is something wrong in agent_name: {self.agent_name} " + f"agent_config: {self.project_config} " + ) + + def _init_prompt_config(self): + '''Initialize prompt configuration (PromptConfig).''' + # Load prompt configuration based on the agent's name and project config + if self.agent_name and self.project_config and self.project_config.prompt_configs: + self.prompt_config = self.project_config.prompt_configs.get( + self.agent_name, PromptConfig() + ) + self._init_prompt_manager() + else: + self.prompt_config = PromptConfig() # Fallback to default prompt config + self._init_prompt_manager() + + def _init_prompt_manager(self): + '''Initialize prompt manager from prompt configurations.''' + self.prompt_manager = BasePromptManager.from_config( + system_prompt=self.system_prompt, + input_template=self.input_template, + output_template=self.output_template, + prompt=self.prompt, + prompt_config=self.prompt_config, + ) + + def copy_config(self) -> ProjectConfig: + '''Create a copy of the current agent's configuration for use in a project.''' + return ProjectConfig( + agent_configs={self.agent_config.config_name: self.agent_config} if self.agent_config else {}, + prompt_configs={self.prompt_config.config_name: self.prompt_config} if self.prompt_config else {}, + model_configs={self.model_config.config_name: self.model_config} if self.model_config else {}, + ) + + @classmethod + def init_from_project_config(cls, agent_name: str, project_config: ProjectConfig) -> 'BaseAgent': + '''Create a new instance of the agent from project configuration.''' + agent_config = project_config.agent_configs[agent_name] + agent_type = agent_config.agent_type + model_config = ( + project_config.model_configs[agent_config.llm_config_name] + if agent_config.llm_config_name + else project_config.model_configs["default_chat"] + ) + prompt_config = ( + project_config.prompt_configs[agent_config.prompt_config_name] + if agent_config.prompt_config_name + else PromptConfig() + ) + return cls.get_wrapper(agent_type)( + agent_config=agent_config, + model_config=model_config, + prompt_config=prompt_config, + project_config=project_config + ) + + @classmethod + def get_wrapper(cls, agent_type: str) -> Type['BaseAgent']: + '''Retrieve the appropriate agent class wrapper based on the agent type. + + Args: + agent_type (str): + A string that specifies the type of agent for which a wrapper + class is requested. This string is used to look up the + appropriate agent class from the registered agent type registry. + + Returns: + Type['BaseAgent']: + The method returns the appropriate subclass of BaseAgent based on + the provided agent_type. If the agent_type is found in the + class's _type_registry or _registry, it returns the corresponding + class. If not found, it raises a KeyError. + ''' + if agent_type in cls._type_registry: + return cls._type_registry[agent_type] + elif agent_type in cls._registry: + return cls._registry[agent_type] + else: + raise KeyError( + f"Agent Library is missing " + f"{agent_type}, please check your agent type" + ) + + def step( + self, + query: Message, + memory_manager: Optional[BaseMemoryManager]=None, + session_index: str = "default", + **kwargs + ) -> Optional[Message]: + '''Process a query and return the agent's response. + + Args: + query (Message): + An instance of the Message class containing the + input query for the agent. + memory_manager (Optional[BaseMemoryManager]): + An optional memory manager instance for managing message history. 
+ session_index (str, default="default"): + A string representing the session index for message tracking and management. + kwargs: Additional keyword arguments for extended functionality. + + Returns: + Optional[Message]: + The final response from the agent as an instance of the Message class, + or None if no response is available. + ''' + session_index = query.session_index or session_index + message = None + # Retrieve the final message from the step_stream generator + for message in self.step_stream( + query, memory_manager, session_index, **kwargs + ): + pass + return message + + def step_stream( + self, + query: Message, + memory_manager: Optional[BaseMemoryManager]=None, + session_index: str = "default" + ) -> Generator[Message, None, None]: + '''Stream the agent's responses over multiple messages. + + Args: + query (Message): + An instance of the Message class containing the + input query for the agent. + memory_manager (Optional[BaseMemoryManager]): + An optional memory manager instance for managing message history. + session_index (str, default="default"): + A string representing the session index for message tracking and management. + + Returns: + Generator[Message, None, None]: + A generator that yields multiple Message instances as responses to the input query. + ''' + raise NotImplementedError( + f"Agent Wrapper [{type(self).__name__}]" + f" is missing the required `step_stream`" + f" method.", + ) + + def pre_print( + self, + query: Message, + memory_manager: BaseMemoryManager=None, + session_index: str = "default", + **kwargs + ) -> None: + """Pre-print this agent's prompt format. + + Args: + query (Message): + An instance of the Message class containing the + input query for the agent. + memory_manager (Optional[BaseMemoryManager]): + An optional memory manager instance for managing message history. + session_index (str, default="default"): + A string representing the session index for message tracking and management. 
+ """ + session_index = query.session_index or session_index + # Generate the output message before proceeding with the agent action + output_message = self.inherit_extrainfo(query) + output_message = self.start_action_step(output_message) + + # Insert query into history memory + self.append_history(query) + self.update_memory_manager(query, memory_manager) + + # Retrieve memory for the current session + memory = self.get_memory(session_index) + prompt = self.prompt_manager.pre_print(query=query, memory=memory, **kwargs) + + # Displaying the formatted prompt for the agent + title = f"<<<<{self.agent_name}'s prompt>>>>" + print("#"*len(title) + f"\n{title}\n" + "#"*len(title) + f"\n\n{prompt}\n\n") + + def inherit_extrainfo(self, input: Message): + """Incorporate additional information from the last message into the new message.""" + output_message = Message( + role_name=self.agent_name, + role_type="assistant", + session_index=input.session_index, + ) + output_message.update_input(input) + output_message.global_kwargs = copy.deepcopy(input.global_kwargs) # Preserve global args + return output_message + + def registry_actions(self, actions): + '''Register actions related to the LLM model.''' + self.action_list = actions + + def start_action_step(self, message: Message) -> Message: + '''Perform actions before predicting the response from the agent.''' + # (To be implemented) Additional actions can be done here + return message + + def end_action_step(self, message: Message) -> Message: + '''Perform actions after the agent has predicted a response.''' + # (To be implemented) Additional actions can be done here + return message + + def update_memory_manager( + self, + message: Message, + memory_manager: Optional[BaseMemoryManager] = None, + ): + """Update the memory manager with the latest message.""" + if memory_manager: + memory_manager.append(message, self.agent_name) + + def init_history(self, memory: Memory = None) -> Memory: + """Initialize message history.""" + return Memory(messages=[]) + + def update_history(self, message: Message): + """Update the agent's internal history with a new message.""" + self.memory.append(message) + + def append_history(self, message: Message): + """Append a new message to the agent's history.""" + self.memory.append(message) + + def clear_history(self): + """Clear the agent's memory history.""" + self.memory.clear() + self.memory = self.init_history() + + def get_memory( + self, + session_index: str, + memory_manager: Optional[BaseMemoryManager] = None, + ) -> Memory: + """Retrieve the agent's memory for a given session index.""" + if memory_manager: + return memory_manager.get_memory_pool(session_index=session_index) + return self.memory + + def memory_to_format_messages( + self, + attributes: dict[str, Union[any, List[any]]] = {}, + filter_type: Optional[Literal['select', 'filter']] = None, + *, + return_all: bool = True, + content_key: str = "response", + with_tag: bool = False, + format_type: Literal['raw', 'tuple', 'dict', 'str']='raw', + logic: Literal['or', 'and'] = 'and' + ) -> List: + """Format the stored memory into specific message formats based on parameters.""" + kwargs = locals() + kwargs.pop("self") + kwargs.pop("class") + return self.memory.to_format_messages(**kwargs) + + def _get_model(self) -> ModelWrapperBase: + """Retrieve the model wrapper based on the model configuration.""" + if isinstance(self.model_config, ModelConfig): + return get_model(self.model_config) + elif isinstance(self.model_config, TempLLMConfig): + return 
diff --git a/muagent/agents/functioncall_agent.py b/muagent/agents/functioncall_agent.py new file mode 100644 index 0000000..b1c614c --- /dev/null +++ b/muagent/agents/functioncall_agent.py @@ -0,0 +1,237 @@ +from abc import ABCMeta +from pydantic import BaseModel +import os +from typing import ( + List, + Union, + Generator, + Optional, +) + +from loguru import logger + +from ..schemas import ( + Message, + Memory, + PromptConfig, + AgentConfig, + ProjectConfig +) +from .base_agent import BaseAgent +from ..schemas.models import ModelConfig +from ..memory_manager import BaseMemoryManager + +from muagent.connector.schema import LogVerboseEnum + + + +funtioncall_output_template = '''#### RESPONSE OUTPUT FORMAT +**Thoughts:** According to the previous context, plan the approach for using the tool effectively. + +**Action Status:** stopped, tool_using or code_executing +Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary. +Use 'tool_using' when the current step in the process involves utilizing a tool to proceed. + +**Action:** + +If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this: +```json +{ + "tool_name": "$TOOL_NAME", + "tool_params": $args +} +``` + +If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this: +```text +The final response or instructions to the user question. +``` +''' + + +class FunctioncallAgent(BaseAgent): + """FunctioncallAgent class that extends the BaseAgent class for + function calling. + + FunctioncallAgent Examples: + .. code-block:: python + from muagent.schemas import Message, Memory + from muagent.agents import FunctioncallAgent + from muagent import get_project_config_from_env + + + # log level: print the prompt and the llm prediction + os.environ["log_verbose"] = "0" + + AGENT_CONFIGS = { + "codefuse_function_caller": { + "config_name": "codefuse_function_caller", + "agent_type": "FunctioncallAgent", + "agent_name": "codefuse_function_caller", + "llm_config_name": "qwener" + } + } + os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + + project_config = get_project_config_from_env() + tools = ["KSigmaDetector", "MetricsQuery"] + agent = FunctioncallAgent( + agent_name="codefuse_function_caller", + project_config=project_config, + tools=tools + ) + + query_content = "Help me query the data of the server 127.0.0.1 at 10 o'clock" + query = Message( + role_name="human", + role_type="user", + content=query_content, + ) + # agent.pre_print(query) + output_message = agent.step(query) + print("### input ###\n", output_message.input_text) + print("### content ###\n", output_message.content) + print("### step content ###\n", output_message.step_content) + """ + + agent_type: str = "FunctioncallAgent" + """The type of the agent, which is defined as 'FunctioncallAgent'.""" + + agent_id: str + """Unique identifier for the agent.""" + + def __init__( + self, + agent_name: str = "codefuse_function_caller", + system_prompt: str = "you are a helpful assistant!\n", + input_template: Union[str, BaseModel] = "", + output_template: Union[str, BaseModel] = funtioncall_output_template, + prompt: Optional[str] = None, + agents: List[str] = [], + tools: List[str] = [], + agent_desc: str = "", + *, + agent_config: Optional[AgentConfig] = None, + model_config: Optional[ModelConfig] = None, + prompt_config: Optional[PromptConfig] = PromptConfig(),
project_config: Optional[ProjectConfig] = None, + # + log_verbose: str = "0", + ): + + super().__init__( + agent_name=agent_name, + system_prompt=system_prompt, + input_template=input_template, + output_template=output_template or funtioncall_output_template, + prompt=prompt, + agents=agents, + tools=tools, + agent_desc=agent_desc, + agent_config=agent_config, + model_config=model_config, + prompt_config=prompt_config, + project_config=project_config, + log_verbose=log_verbose + ) + + def step_stream( + self, + query: Message, + memory_manager: Optional[BaseMemoryManager]=None, + session_index: str = "default", + memory: Optional[Memory] = None, + **kwargs + ) -> Generator[Message, None, None]: + '''agent response from multi-message''' + + session_index = query.session_index or session_index + + # insert query into memory + self.append_history(query) + self.update_memory_manager(query, memory_manager) + + # transform query into output_message.input_text + output_message = self.inherit_extrainfo(query) + output_message = self.start_action_step(output_message) + + # get memory from self or memory_manager + memory = memory or self.get_memory(session_index) + + # generate prompt by prompt manager + prompt = self.prompt_manager.generate_prompt( + query=output_message, memory=memory, tools=self.tools + ) + + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"{self.agent_name} prompt: {prompt}") + + # predict + model = self._get_model() + content = model.predict(prompt) + + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"{self.agent_name} content: {content}") + + # update infomation + output_message.update_content(content) + + # common parse llm' content to message + output_message = self.prompt_manager.parser(output_message) + + # todo: action step + output_message, observation_message = self.message_util.step_router( + output_message, + session_index=session_index, + tools=self.tools, + **kwargs + ) + # end + output_message = self.end_action_step(output_message) + + # update self_memory and memory pool + self.append_history(output_message) + self.update_memory_manager(output_message, memory_manager) + if observation_message: + self.append_history(observation_message) + self.update_memory_manager(observation_message, memory_manager) + + yield output_message + + def pre_print( + self, + query: Message, + memory_manager: BaseMemoryManager=None, + tools: List[str] = [], + session_index: str = "default" + + ) -> None: + """pre print this agent prompt format""" + session_index = query.session_index or session_index + # + output_message = self.inherit_extrainfo(query) + output_message = self.start_action_step(output_message) + + # insert query into memory + self.append_history(query) + self.update_memory_manager(query, memory_manager) + + # get memory from self or memory_manager + memory = self.get_memory(session_index) + + prompt = self.prompt_manager.pre_print(query=query, memory=memory, tools=tools or self.tools) + + title = f"<<<<{self.agent_name}'s prompt>>>>" + print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n") + + def start_action_step(self, message: Message) -> Message: + '''do action before agent predict ''' + # action_json = self.start_action() + # message["customed_kargs"]["xx"] = action_json + return message + + def end_action_step(self, message: Message) -> Message: + '''do action after agent predict ''' + # action_json = self.end_action() + # message["customed_kargs"]["xx"] = action_json + return 
message \ No newline at end of file diff --git a/muagent/agents/group_agent.py b/muagent/agents/group_agent.py new file mode 100644 index 0000000..0493b1b --- /dev/null +++ b/muagent/agents/group_agent.py @@ -0,0 +1,227 @@ +from pydantic import BaseModel +from typing import ( + List, + Union, + Generator, + Optional, +) + +from loguru import logger + +from ..schemas import ( + Message, + Memory, + PromptConfig, + AgentConfig, + ProjectConfig +) +from .base_agent import BaseAgent +from ..schemas.models import ModelConfig +from ..memory_manager import BaseMemoryManager + +from muagent.connector.schema import LogVerboseEnum + + + +group_output_template = """#### RESPONSE OUTPUT FORMAT +**Thoughts:** think step by step about the reason why you select one role + +**Role:** Select one role from agent names. No other information. +""" + +group_output_template_zh = """#### 响应输出格式 +**思考:** 一步一步思考你选择一个角色的原因 + +**角色:** 从代理名称中选择一个角色。不要包含其他信息。 +""" + +class GroupAgent(BaseAgent): + """GroupAgent class that extends the BaseAgent class for + managing a team of agents to complete a task. + + GroupAgent Examples: + .. code-block:: python + from muagent.tools import TOOL_SETS + from muagent.schemas import Message + from muagent.agents import BaseAgent + from muagent.project_manager import get_project_config_from_env + + + tools = list(TOOL_SETS) + tools = ["KSigmaDetector", "MetricsQuery"] + role_prompt = "you are a helpful assistant!" + + AGENT_CONFIGS = { + "grouper": { + "agent_type": "GroupAgent", + "agent_name": "grouper", + "agents": ["codefuse_reacter_1", "codefuse_reacter_2"] + }, + "codefuse_reacter_1": { + "agent_type": "ReactAgent", + "agent_name": "codefuse_reacter_1", + "tools": tools, + }, + "codefuse_reacter_2": { + "agent_type": "ReactAgent", + "agent_name": "codefuse_reacter_2", + "tools": tools, + } + } + os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + + # log level: print the prompt and the llm prediction + os.environ["log_verbose"] = "0" + + # + project_config = get_project_config_from_env() + agent = BaseAgent.init_from_project_config( + "grouper", project_config + ) + + query_content = "Please check whether the server 127.0.0.1 had any anomaly at 10 o'clock and help me judge it" + query = Message( + role_name="human", + role_type="user", + content=query_content, + ) + # agent.pre_print(query) + output_message = agent.step(query) + print("input:", output_message.input_text) + print("content:", output_message.content) + print("step_content:", output_message.step_content) + """ + + agent_type: str = "GroupAgent" + """The type of the agent, which is defined as 'GroupAgent'.""" + + agent_id: str + """Unique identifier for the agent.""" + + def __init__( + self, + agent_name: str = "codefuse_grouper", + system_prompt: str = "you are a helpful assistant!\n", + input_template: Union[str, BaseModel] = "", + agents: List[str] = [], + tools: List[str] = [], + agent_desc: str = "", + *, + agent_config: Optional[AgentConfig] = None, + model_config: Optional[ModelConfig] = None, + prompt_config: Optional[PromptConfig] = PromptConfig(), + project_config: Optional[ProjectConfig] = None, + # + log_verbose: str = "0", + **kwargs, + ): + + super().__init__( + agent_name=agent_name, + system_prompt=system_prompt, + input_template=input_template, + output_template=group_output_template, + prompt="", + agents=agents, + tools=tools, + agent_desc=agent_desc, + agent_config=agent_config, + model_config=model_config, + prompt_config=prompt_config, + project_config = project_config, + log_verbose=log_verbose, + **kwargs, + ) + + def step_stream( + self, + query: Message,
+ memory_manager: Optional[BaseMemoryManager]=None, + session_index: str = "default" + ) -> Generator[Message, None, None]: + '''Stream the agent's responses based on an input multi-message query.''' + + session_index = query.session_index or session_index + + # insert query into memory + self.append_history(query) + self.update_memory_manager(query, memory_manager) + + # transform query into output_message.input_text + select_message = self.inherit_extrainfo(query) + select_message = self.start_action_step(select_message) + + # get memory from self or memory_manager + memory = self.get_memory(session_index) + + # generate prompt by prompt manager + agents = [self.get_agent_by_name(agent_name) for agent_name in self.agents] + agent_descs = [agent.agent_desc or agent.system_prompt for agent in agents] + prompt = self.prompt_manager.generate_prompt( + query=select_message, memory=memory, + tools=self.tools, agent_names=self.agents, agent_descs=agent_descs, + ) + + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"{self.agent_name} prompt: {prompt}") + + # predict + model = self._get_model() + content = model.predict(prompt) + + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"{self.agent_name} content: {content}") + + # update information + select_message.update_content(content) + # parse the llm content into the message + select_message = self.prompt_manager.parser(select_message) + + output_message = None + if select_message.parsed_content.get("Role", "") in self.agents: + agent_name = select_message.parsed_content.get("Role", "") + agent = self.get_agent_by_name(agent_name) + + # update self_memory + self.append_history(select_message) + self.update_memory_manager(select_message, memory_manager) + + # Pass everything except the selected role on to the next agent + logger.debug(f"{select_message.parsed_content}") + select_message.parsed_content = { + k: v for k, v in select_message.parsed_content.items() if k != "Role" + } + logger.debug(f"{select_message.parsed_content}") + + # only query to next agent + query_bak = self.inherit_extrainfo(query) + for output_message in agent.step_stream(query_bak, memory_manager, session_index): + yield output_message or select_message + + # + output_message = self.end_action_step(output_message) + + select_message.update_content(output_message.step_content) + select_message.update_parsed_content(output_message.parsed_content) + select_message.update_spec_parsed_content(output_message.spec_parsed_content) + + # update memory pool + self.append_history(output_message) + self.update_memory_manager(select_message, memory_manager) + yield select_message + + def get_agent_by_name(self, agent_name: str) -> BaseAgent: + """Instantiate an agent by name from the project config.""" + return self.init_from_project_config(agent_name, self.project_config) + + def start_action_step(self, message: Message) -> Message: + '''Perform any required actions before predicting the response of the agent.''' + # action_json = self.start_action() + # message["customed_kargs"]["xx"] = action_json + return message + + def end_action_step(self, message: Message) -> Message: + '''Perform any required actions after the agent has predicted the response.''' + # action_json = self.end_action() + # message["customed_kargs"]["xx"] = action_json + return message \ No newline at end of file
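For context, a minimal sketch of how the grouper's sub-agents are re-materialized (the config names come from the GroupAgent docstring example above; the lookup itself goes through the `agent_type` registry):

```python
from muagent.agents import BaseAgent
from muagent.project_manager import get_project_config_from_env

# Assumes AGENT_CONFIGS has been exported as in the GroupAgent example.
project_config = get_project_config_from_env()

# get_agent_by_name delegates to init_from_project_config, which resolves the
# configured agent_type ("ReactAgent") to its class and instantiates it.
reacter = BaseAgent.init_from_project_config("codefuse_reacter_1", project_config)
print(type(reacter).__name__)  # -> ReactAgent
```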
+from abc import ABCMeta
+from pydantic import BaseModel
+from typing import (
+    List,
+    Union,
+    Generator,
+    Optional,
+)
+import copy
+from loguru import logger
+
+from ..schemas import (
+    Message,
+    Memory,
+    PromptConfig,
+    AgentConfig,
+    ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..schemas.common import ActionStatus
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+
+react_output_template = '''#### RESPONSE OUTPUT FORMAT
+**Thoughts:** According to the previous observations, plan the approach for using the tool effectively.
+
+**Action Status:** stoped, tool_using or code_executing
+Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary.
+Use 'tool_using' when the current step in the process involves utilizing a tool to proceed.
+Use 'code_executing' when the current step requires writing and executing code.
+
+**Action:**
+
+If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
+```json
+{
+    "tool_name": "$TOOL_NAME",
+    "tool_params": "$INPUT"
+}
+```
+
+If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
+```python
+Write your running code here
+```
+
+If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
+```text
+The final response or instructions to the user question.
+```
+
+**Observation:** Check the results and effects of the executed action.
+
+... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
+
+**Thoughts:** Conclude the final response to the user question.
+
+**Action Status:** stoped
+
+**Action:** The final answer or guidance to the user question.
+'''
+
+
+class ReactAgent(BaseAgent):
+    """ReactAgent class that extends the BaseAgent class for completing tasks by reacting.
+
+    ReactAgent Examples:
+    .. code-block:: python
+        import os, json
+
+        from muagent.tools import TOOL_SETS
+        from muagent.schemas import Message
+        from muagent.agents import BaseAgent
+        from muagent import get_project_config_from_env
+
+        # log level: whether to print the prompt and the llm prediction
+        os.environ["log_verbose"] = "0"
+
+        tools = list(TOOL_SETS)
+        tools = ["KSigmaDetector", "MetricsQuery"]
+        role_prompt = "you are a helpful assistant!"
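+        # "KSigmaDetector" and "MetricsQuery" are demo tools from muagent's
+        # TOOL_SETS registry; substitute any tool names registered in your
+        # own project here.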
+
+        AGENT_CONFIGS = {
+            "reacter": {
+                "system_prompt": role_prompt,
+                "agent_type": "ReactAgent",
+                "agent_name": "reacter",
+                "tools": tools,
+                "llm_config_name": "qwen_chat"
+            }
+        }
+        os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+        #
+        project_config = get_project_config_from_env()
+        agent = BaseAgent.init_from_project_config(
+            "reacter", project_config
+        )
+
+        query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+        query = Message(
+            role_name="human",
+            role_type="user",
+            content=query_content,
+        )
+        # agent.pre_print(query)
+        output_message = agent.step(query)
+        print("### input ### ", output_message.input_text)
+        print("### content ### ", output_message.content)
+        print("### step content ### ", output_message.step_content)
+    """
+
+    agent_type: str = "ReactAgent"
+    """The type of the agent, which is defined as 'ReactAgent'."""
+
+    agent_id: str
+    """Unique identifier for the agent."""
+
+    def __init__(
+        self,
+        agent_name: str = "codefuse_reacter",
+        system_prompt: str = "you are a helpful assistant!\n",
+        input_template: Union[str, BaseModel] = "",
+        output_template: Union[str, BaseModel] = react_output_template,
+        prompt: Optional[str] = None,
+        stop: str = '**Observation:**',
+        agents: List[str] = [],
+        tools: List[str] = [],
+        agent_desc: str = "",
+        *,
+        agent_config: Optional[AgentConfig] = None,
+        model_config: Optional[ModelConfig] = None,
+        prompt_config: Optional[PromptConfig] = PromptConfig(),
+        project_config: Optional[ProjectConfig] = None,
+        #
+        chat_turn: int = 3,
+        log_verbose: str = "0",
+    ):
+        super().__init__(
+            agent_name=agent_name,
+            system_prompt=system_prompt,
+            input_template=input_template,
+            output_template=output_template or react_output_template,
+            prompt=prompt,
+            agents=agents,
+            tools=tools,
+            agent_desc=agent_desc,
+            agent_config=agent_config,
+            model_config=model_config,
+            prompt_config=prompt_config,
+            project_config=project_config,
+            log_verbose=log_verbose
+        )
+        #
+        self.stop = stop
+        self.chat_turn = chat_turn
+
+    def step_stream(
+        self,
+        query: Message,
+        memory_manager: Optional[BaseMemoryManager]=None,
+        session_index: str = "default"
+    ) -> Generator[Message, None, None]:
+        '''Stream the agent's responses based on an input multi-message query.'''
+
+        session_index = query.session_index or session_index
+        step_nums = copy.deepcopy(self.chat_turn)
+        react_memory = Memory(messages=[])
+
+        # insert query into memory
+        self.append_history(query)
+        self.update_memory_manager(query, memory_manager)
+
+        # transform query into output_message.input_text
+        output_message = self.inherit_extrainfo(query)
+        output_message = self.start_action_step(output_message)
+
+        # get memory from self or memory_manager
+        memory = self.get_memory(session_index)
+
+        idx = 0
+        while step_nums > 0:
+            output_message.content = output_message.step_content
+            prompt = self.prompt_manager.generate_prompt(
+                query=output_message,
+                memory=memory,
+                react_memory=react_memory,
+                tools=self.tools
+            )
+
+            try:
+                model = self._get_model()
+                content = model.predict(prompt, self.stop)
+            except Exception as e:
+                logger.error(f"error : {e}, prompt: {prompt}")
+                raise RuntimeError(f"error : {e}, prompt: {prompt}")
+
+            output_message.content = content
+            output_message.step_content += f"\n{content}"
+            yield output_message
+
+            if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+                logger.debug(f"{self.agent_name}, {idx} iteration prompt: {prompt}")
+
+            if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+                logger.info(f"{self.agent_name}, {idx}
iteration step_run: {content}")
+
+            output_message = self.prompt_manager.parser(output_message)
+            # stop early once a finished signal is received
+            if (output_message.action_status == ActionStatus.FINISHED or
+                    output_message.action_status == ActionStatus.STOPPED):
+                output_message.spec_parsed_contents.append(output_message.spec_parsed_content)
+                break
+            # according to the output, choose one action: code execution or tool usage
+            output_message, observation_message = self.message_util.step_router(
+                output_message,
+                session_index=session_index,
+                tools=self.tools,
+            )
+
+            # only record content
+            react_message = copy.deepcopy(output_message)
+            react_memory.append(react_message)
+            if observation_message:
+                react_memory.append(observation_message)
+                output_message.update_parsed_content(observation_message.parsed_content)
+                output_message.update_spec_parsed_content(observation_message.parsed_content)
+            idx += 1
+            step_nums -= 1
+            yield output_message
+
+        # end
+        output_message = self.end_action_step(output_message)
+
+        # update self_memory and memory pool
+        self.append_history(output_message)
+        self.update_memory_manager(output_message, memory_manager)
+        yield output_message
+
+    def pre_print(
+        self,
+        query: Message,
+        memory_manager: BaseMemoryManager=None,
+        tools: List[str] = [],
+        session_index: str = "default"
+    ) -> None:
+        """Pre-print this agent's prompt format."""
+        session_index = query.session_index or session_index
+        react_memory = Memory(messages=[])
+        #
+        output_message = self.inherit_extrainfo(query)
+        output_message = self.start_action_step(output_message)
+
+        # insert query into memory
+        self.append_history(query)
+        self.update_memory_manager(query, memory_manager)
+
+        # get memory from self or memory_manager
+        memory = self.get_memory(session_index)
+
+        prompt = self.prompt_manager.pre_print(
+            query=query,
+            memory=memory,
+            tools=tools or self.tools,
+            react_memory=react_memory
+        )
+
+        title = f"<<<<{self.agent_name}'s prompt>>>>"
+        print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
+
+    def start_action_step(self, message: Message) -> Message:
+        '''Perform any required actions before predicting the response of the agent.'''
+        # action_json = self.start_action()
+        # message["customed_kargs"]["xx"] = action_json
+        return message
+
+    def end_action_step(self, message: Message) -> Message:
+        '''Perform any required actions after the agent has predicted the response.'''
+        # action_json = self.end_action()
+        # message["customed_kargs"]["xx"] = action_json
+        return message
\ No newline at end of file
diff --git a/muagent/agents/single_agent.py b/muagent/agents/single_agent.py new file mode 100644 index 0000000..741253a --- /dev/null +++ b/muagent/agents/single_agent.py @@ -0,0 +1,205 @@
+from abc import ABCMeta
+from pydantic import BaseModel
+import os
+from typing import (
+    List,
+    Union,
+    Generator,
+    Optional,
+)
+
+from loguru import logger
+
+from ..schemas import (
+    Message,
+    Memory,
+    PromptConfig,
+    AgentConfig,
+    ProjectConfig
+)
+from .base_agent import BaseAgent
+from ..schemas.models import ModelConfig
+from ..memory_manager import BaseMemoryManager
+
+from muagent.connector.schema import LogVerboseEnum
+
+
+class SingleAgent(BaseAgent):
+    """SingleAgent class that extends the BaseAgent class for simple single-agent tasks.
+
+    SingleAgent Examples:
+    ..
code-block:: python
+        import os, json
+
+        from muagent.tools import TOOL_SETS
+        from muagent.schemas import Message, Memory
+        from muagent.agents import BaseAgent
+        from muagent import get_project_config_from_env
+
+        tools = list(TOOL_SETS)
+        tools = ["KSigmaDetector", "MetricsQuery"]
+        AGENT_CONFIGS = {
+            "codefuse_simpler": {
+                "agent_type": "SingleAgent",
+                "agent_name": "codefuse_simpler",
+                "tools": tools,
+                "llm_config_name": "qwen_chat"
+            }
+        }
+        os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+        project_config = get_project_config_from_env()
+        agent = BaseAgent.init_from_project_config(
+            "codefuse_simpler", project_config
+        )
+
+        query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
+        query = Message(
+            role_name="human",
+            role_type="user",
+            input_text=query_content,
+        )
+        # agent.pre_print(query)
+        output_message = agent.step(query)
+        print("### input ###", output_message.input_text)
+        print("### content ###", output_message.content)
+        print("### step content ###", output_message.step_content)
+    """
+
+    agent_type: str = "SingleAgent"
+    """The type of the agent, which is defined as 'SingleAgent'."""
+
+    agent_id: str
+    """Unique identifier for the agent."""
+
+    def __init__(
+        self,
+        agent_name: str = "codefuse_simpler",
+        system_prompt: str = "you are a helpful assistant!\n",
+        input_template: Union[str, BaseModel] = "",
+        output_template: Union[str, BaseModel] = "",
+        prompt: Optional[str] = None,
+        agents: List[str] = [],
+        tools: List[str] = [],
+        agent_desc: str = "",
+        *,
+        agent_config: Optional[AgentConfig] = None,
+        model_config: Optional[ModelConfig] = None,
+        prompt_config: Optional[PromptConfig] = PromptConfig(),
+        project_config: Optional[ProjectConfig] = None,
+        #
+        log_verbose: str = "0",
+    ):
+
+        super().__init__(
+            agent_name=agent_name,
+            system_prompt=system_prompt,
+            input_template=input_template,
+            output_template=output_template,
+            prompt=prompt,
+            agents=agents,
+            tools=tools,
+            agent_desc=agent_desc,
+            agent_config=agent_config,
+            model_config=model_config,
+            prompt_config=prompt_config,
+            project_config=project_config,
+            log_verbose=log_verbose
+        )
+
+    def step_stream(
+        self,
+        query: Message,
+        memory_manager: Optional[BaseMemoryManager]=None,
+        session_index: str = "default"
+    ) -> Generator[Message, None, None]:
+        '''Stream the agent's responses based on an input multi-message query.'''
+
+        session_index = query.session_index or session_index
+
+        # Insert the received query into memory
+        self.append_history(query)
+        self.update_memory_manager(query, memory_manager)
+
+        # Create an output message containing inherited information from the input query
+        output_message = self.inherit_extrainfo(query)
+        output_message = self.start_action_step(output_message)
+
+        # Retrieve memory for the current session, either from self or the memory manager
+        memory = self.get_memory(session_index)
+
+        # Generate a prompt using the prompt manager
+        prompt = self.prompt_manager.generate_prompt(
+            query=output_message, memory=memory, tools=self.tools
+        )
+
+        if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+            logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+        # Predict the content using the agent's model
+        model = self._get_model()
+        content = model.predict(prompt)
+
+        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+            logger.info(f"{self.agent_name} content: {content}")
+
+        # Update the output message with the predicted content
+        output_message.update_content(content)
+
+        # Parse the output content into a structured message format
+        output_message =
self.prompt_manager.parser(output_message) + + # Process any actions or observations required based on the output message + output_message, observation_message = self.message_util.step_router( + output_message, + session_index=session_index, + tools=self.tools, + ) + + # Wrap up any action steps + output_message = self.end_action_step(output_message) + + # Update memory with the output message and any observations + self.append_history(output_message) + self.update_memory_manager(output_message, memory_manager) + if observation_message: + self.append_history(observation_message) + self.update_memory_manager(observation_message, memory_manager) + + yield output_message # Yield the constructed output message + + def pre_print( + self, + query: Message, + memory_manager: BaseMemoryManager=None, + tools: List[str] = [], + session_index: str = "default" + ) -> None: + """Prepare and print the prompt format for this agent based on the input query.""" + session_index = query.session_index or session_index + # Prepare an output message with inherited information + output_message = self.inherit_extrainfo(query) + output_message = self.start_action_step(output_message) + + # Insert query into memory for later reference + self.append_history(query) + self.update_memory_manager(query, memory_manager) + + # Get the current memory for the session + memory = self.get_memory(session_index) + + # Generate and format the prompt + prompt = self.prompt_manager.pre_print(query=query, memory=memory, tools=tools or self.tools) + + # Display the prompt for this agent + title = f"<<<<{self.agent_name}'s prompt>>>>" + print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n") + + def start_action_step(self, message: Message) -> Message: + '''Perform any required actions before predicting the response of the agent.''' + # action_json = self.start_action() + # message["customed_kargs"]["xx"] = action_json + return message + + def end_action_step(self, message: Message) -> Message: + '''Perform any required actions after the agent has predicted the response.''' + # action_json = self.end_action() + # message["customed_kargs"]["xx"] = action_json + return message \ No newline at end of file diff --git a/muagent/agents/task_agent.py b/muagent/agents/task_agent.py new file mode 100644 index 0000000..2aa24da --- /dev/null +++ b/muagent/agents/task_agent.py @@ -0,0 +1,291 @@ +from abc import ABCMeta +from pydantic import BaseModel +import os +from typing import ( + List, + Union, + Generator, + Optional, + Tuple, +) +import copy +from loguru import logger + +from ..schemas import ( + Message, + Memory, + PromptConfig, + AgentConfig, + ProjectConfig +) +from .base_agent import BaseAgent +from ..schemas.models import ModelConfig +from ..memory_manager import BaseMemoryManager +from ..base_configs.prompts import PLAN_EXECUTOR_PROMPT + +from muagent.connector.schema import LogVerboseEnum + + + + +executor_output_template = '''#### RESPONSE OUTPUT FORMAT +**Thoughts:** Considering the session records and task records, decide whether the current step requires the use of a tool or code_executing. +Solve the problem only displaying the thought process necessary for the current step of solving the problem. + +**Action Status:** stoped, tool_using or code_executing +Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary. +Use 'tool_using' when the current step in the process involves utilizing a tool to proceed. 
+Use 'code_executing' when the current step requires writing and executing code.
+
+**Action:**
+
+If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
+```json
+{
+    "tool_name": "$TOOL_NAME",
+    "tool_params": "$INPUT"
+}
+```
+
+If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
+```python
+Write your running code here
+```
+
+If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
+```text
+The final response or instructions to the user question.
+```'''
+
+
+class TaskAgent(BaseAgent):
+    """TaskAgent class that extends the BaseAgent class for delegating a query into multiple tasks.
+
+    TaskAgent Examples:
+    .. code-block:: python
+
+        import os, json
+
+        from muagent.tools import TOOL_SETS
+        from muagent.schemas import Message
+        from muagent.agents import BaseAgent
+        from muagent import get_project_config_from_env
+
+        tools = list(TOOL_SETS)
+        tools = ["KSigmaDetector", "MetricsQuery"]
+        role_prompt = "you are a helpful assistant!"
+
+        AGENT_CONFIGS = {
+            "tasker": {
+                "system_prompt": role_prompt,
+                "agent_type": "TaskAgent",
+                "agent_name": "tasker",
+                "tools": tools,
+                "llm_config_name": "qwen_chat"
+            }
+        }
+        os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS)
+
+        #
+        project_config = get_project_config_from_env()
+        agent = BaseAgent.init_from_project_config(
+            "tasker", project_config
+        )
+
+        query_content = "先帮我获取下127.0.0.1这个服务器在10点的数,然后在帮我判断下数据是否存在异常"
+        query = Message(
+            role_name="human",
+            role_type="user",
+            content=query_content,
+        )
+        # agent.pre_print(query)
+        output_message = agent.step(query)
+        print("### input ###\n", output_message.input_text)
+        print("### content ###\n", output_message.content)
+        print("### step content ###\n", output_message.step_content)
+    """
+
+    agent_type: str = "TaskAgent"
+    """The type of the agent, which is defined as 'TaskAgent'."""
+
+    agent_id: str
+    """Unique identifier for the agent."""
+
+    def __init__(
+        self,
+        agent_name: str = "codefuse_tasker",
+        system_prompt: str = "you are a helpful assistant!\n",
+        input_template: Union[str, BaseModel] = "",
+        output_template: Union[str, BaseModel] = executor_output_template,
+        prompt: Optional[str] = None,
+        agents: List[str] = [],
+        tools: List[str] = [],
+        agent_desc: str = "",
+        *,
+        agent_config: Optional[AgentConfig] = None,
+        model_config: Optional[ModelConfig] = None,
+        prompt_config: Optional[PromptConfig] = PromptConfig(),
+        project_config: Optional[ProjectConfig] = None,
+        #
+        do_all_task: bool = True,
+        log_verbose: str = "0",
+    ):
+        super().__init__(
+            agent_name=agent_name,
+            system_prompt=system_prompt,
+            input_template=input_template,
+            output_template=output_template or executor_output_template,
+            prompt=prompt,
+            agents=agents,
+            tools=tools,
+            agent_desc=agent_desc,
+            agent_config=agent_config,
+            model_config=model_config,
+            prompt_config=prompt_config,
+            project_config=project_config,
+            log_verbose=log_verbose
+        )
+        #
+        self.do_all_task = do_all_task
+
+    def step_stream(
+        self,
+        query: Message,
+        memory_manager: Optional[BaseMemoryManager]=None,
+        session_index: str = "default"
+    ) -> Generator[Message, None, None]:
+        '''Stream the agent's responses based on an input multi-message query.'''
+
+        session_index = query.session_index or session_index
+
+        # insert query into memory
+        self.append_history(query)
+        self.update_memory_manager(query, memory_manager)
+
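+        # The flow below first asks the llm to extract an execution plan via
+        # PLAN_EXECUTOR_PROMPT, then runs the plan step by step. Judging from
+        # the parsed_content lookups below, the parsed plan is expected to look
+        # roughly like {"PLAN": ["step 1", "step 2", ...], "PLAN_STEP": 0}.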
+        # transform query into output_message.input_text
+        output_message = self.inherit_extrainfo(query)
+        output_message = self.start_action_step(output_message)
+
+        # get memory from self or memory_manager
+        memory = self.get_memory(session_index)
+
+        # generate prompt by prompt manager
+        input_text = query.content or output_message.input_text
+        prompt = PLAN_EXECUTOR_PROMPT.format(
+            **{"content": input_text.replace("*", "")}
+        )
+        if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+            logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+        model = self._get_model()
+        content = model.predict(prompt)
+
+        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+            logger.info(f"{self.agent_name} content: {content}")
+
+        plan_message = Message(
+            session_index=session_index,
+            role_name="plan_extracter",
+            role_type="assistant",
+            content=content,
+            global_kwargs=query.global_kwargs
+        )
+        plan_message = self.prompt_manager.parser(plan_message)
+        # parse the input query into plans and a plan_step
+        plan_step = int(plan_message.parsed_content.get("PLAN_STEP", 0))
+        plans = plan_message.parsed_content.get("PLAN", [input_text])
+
+        if self.do_all_task:
+            # run all tasks step by step
+            for idx, task_content in enumerate(plans[plan_step:]):
+                for output_message in self._execute_line(
+                    task_content, output_message, plan_step+idx, session_index
+                ):
+                    yield output_message
+        else:
+            # run only the current plan step
+            task_content = plans[plan_step]
+            for output_message in self._execute_line(
+                task_content, output_message, plan_step, session_index
+            ):
+                pass
+
+        # end
+        output_message = self.end_action_step(output_message)
+
+        # update self_memory and memory pool
+        self.append_history(output_message)
+        self.update_memory_manager(output_message, memory_manager)
+        yield output_message
+
+    def _execute_line(
+        self,
+        task_content: str,
+        output_message: Message,
+        plan_step,
+        session_index
+    ) -> Generator[Tuple[Message, Memory], None, None]:
+        '''Execute a single task step in the plan.'''
+        query = copy.deepcopy(output_message)
+        query.parsed_content = {"CURRENT_STEP": task_content}
+        query = self.start_action_step(query)
+
+        # get memory from self or memory_manager
+        memory = self.get_memory(session_index)
+
+        for output_message in self._run_stream(
+            query, output_message, memory, session_index
+        ):
+            yield output_message
+        output_message.update_spec_parsed_content(
+            {**output_message.spec_parsed_content, **{"PLAN_STEP": plan_step}}
+        )
+        yield output_message
+
+    def _run_stream(
+        self,
+        query: Message,
+        output_message: Message,
+        memory: Memory,
+        session_index: str
+    ) -> Generator[Tuple[Message, Memory], None, None]:
+        '''Run the llm prediction on the constructed prompt.'''
+        prompt = self.prompt_manager.generate_prompt(
+            query=query,
+            memory=memory,
+            tools=self.tools,
+        )
+
+        if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
+            logger.debug(f"{self.agent_name} prompt: {prompt}")
+
+        model = self._get_model()
+        content = model.predict(prompt)
+
+        if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
+            logger.info(f"{self.agent_name} content: {content}")
+
+        output_message.update_content(content)
+        output_message = self.prompt_manager.parser(output_message)
+        # according to the output, choose one action: code execution or tool usage
+        output_message, observation_message = self.message_util.step_router(
+            output_message, session_index=session_index,
+            tools=self.tools
+        )
+        react_message = copy.deepcopy(output_message)
+        self.append_history(react_message)
+        # task_memory.append(react_message)
+        if observation_message:
+            #
task_memory.append(observation_message) + self.append_history(observation_message) + output_message.update_parsed_content(observation_message.parsed_content) + output_message.update_spec_parsed_content(observation_message.parsed_content) + yield output_message + + def start_action_step(self, message: Message) -> Message: + '''Perform any required actions before predicting the response of the agent.''' + # action_json = self.start_action() + # message["customed_kargs"]["xx"] = action_json + return message + + def end_action_step(self, message: Message) -> Message: + '''Perform any required actions after the agent has predicted the response.''' + # action_json = self.end_action() + # message["customed_kargs"]["xx"] = action_json + return message \ No newline at end of file diff --git a/muagent/agents/user_agent.py b/muagent/agents/user_agent.py new file mode 100644 index 0000000..3393d33 --- /dev/null +++ b/muagent/agents/user_agent.py @@ -0,0 +1,140 @@ +from abc import ABCMeta +from pydantic import BaseModel +import os +from typing import ( + List, + Union, + Generator, + Optional, +) + +from loguru import logger + +from ..schemas import ( + Message, + Memory, + PromptConfig, + AgentConfig, + ProjectConfig +) +from .base_agent import BaseAgent +from ..schemas.models import ModelConfig +from ..memory_manager import BaseMemoryManager + +from muagent.connector.schema import LogVerboseEnum + + +class UserAgent(BaseAgent): + """UserAgent class that extends the BaseAgent class for simulating user' response.""" + + agent_type: str = "UserAgent" + """The type of the agent, which is defined as 'UserAgent'.""" + + agent_id: str + """Unique identifier for the agent.""" + + def __init__( + self, + agent_name: str = "codefuse_user", + system_prompt: str = "", + input_template: Union[str, BaseModel] = "", + output_template: Union[str, BaseModel] = "", + prompt: Optional[str] = None, + agents: List[str] = [], + tools: List[str] = [], + agent_desc: str = "", + *, + agent_config: Optional[AgentConfig] = None, + model_config: Optional[ModelConfig] = None, + prompt_config: Optional[PromptConfig] = PromptConfig(), + project_config: Optional[ProjectConfig] = None, + # + log_verbose: str = "0", + ): + + super().__init__( + agent_name=agent_name, + system_prompt=system_prompt, + input_template=input_template, + output_template=output_template, + prompt=prompt, + agents=agents, + tools=tools, + agent_desc=agent_desc, + agent_config=agent_config, + model_config=model_config, + prompt_config=prompt_config, + project_config=project_config, + log_verbose=log_verbose + ) + + def step_stream( + self, + query: Message, + memory_manager: Optional[BaseMemoryManager]=None, + session_index: str = "default" + ) -> Generator[Message, None, None]: + '''Stream the agent's responses based on an input multi-message query.''' + + session_index = query.session_index or session_index + + # insert query into memory + self.append_history(query) + self.update_memory_manager(query, memory_manager) + + # transform query into output_message.input_text + output_message = self.inherit_extrainfo(query) + output_message = self.start_action_step(output_message) + + # get memory from self or memory_manager + memory = self.get_memory(session_index) + + # generate prompt by prompt manager + prompt = self.prompt_manager.generate_prompt( + query=output_message, memory=memory, tools=self.tools + ) + + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"{self.agent_name} prompt: {prompt}") + + # predict + content = 
input("please answer: \n") + + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"{self.agent_name} content: {content}") + + # update infomation + output_message.update_content(content) + + # common parse llm' content to message + output_message = self.prompt_manager.parser(output_message) + + # todo: action step + output_message, observation_message = self.message_util.step_router( + output_message, + session_index=session_index, + tools=self.tools, + ) + # end + output_message = self.end_action_step(output_message) + + # update self_memory and memory pool + self.append_history(output_message) + self.update_memory_manager(output_message, memory_manager) + if observation_message: + self.append_history(observation_message) + self.update_memory_manager(observation_message, memory_manager) + + yield output_message + + def pre_print( + self, + query: Message, + memory_manager: BaseMemoryManager=None, + tools: List[str] = [], + session_index: str = "default" + + ) -> None: + """pre print this agent prompt format""" + title = f"<<<<{self.agent_name}'s prompt>>>>" + print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{query.content}\n\n") diff --git a/muagent/agents/util.py b/muagent/agents/util.py new file mode 100644 index 0000000..e69de29 diff --git a/muagent/base_configs/env_config.py b/muagent/base_configs/env_config.py index 4b78388..765944c 100644 --- a/muagent/base_configs/env_config.py +++ b/muagent/base_configs/env_config.py @@ -12,19 +12,19 @@ # SOURCE_PATH = os.environ.get("SOURCE_PATH", None) or os.path.join(executable_path, "sources") # 知识库默认存储路径 -KB_ROOT_PATH = os.environ.get("KB_ROOT_PATH", None) or os.path.join(executable_path, "knowledge_base") +KB_ROOT_PATH = os.environ.get("KB_ROOT_PATH", None) or os.path.join(executable_path, "data/knowledge_base") # 代码库默认存储路径 -CB_ROOT_PATH = os.environ.get("CB_ROOT_PATH", None) or os.path.join(executable_path, "code_base") +CB_ROOT_PATH = os.environ.get("CB_ROOT_PATH", None) or os.path.join(executable_path, "data/code_base") # # nltk 模型存储路径 # NLTK_DATA_PATH = os.environ.get("NLTK_DATA_PATH", None) or os.path.join(executable_path, "nltk_data") # 代码存储路径 -JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(executable_path, "jupyter_work") +JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(executable_path, "data/jupyter_work") -# WEB_CRAWL存储路径 -WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base") +# # WEB_CRAWL存储路径 +# WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base") # NEBULA_DATA存储路径 NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data") @@ -32,7 +32,7 @@ # CHROMA 存储路径 CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data") -for _path in [LOG_PATH, KB_ROOT_PATH, CB_ROOT_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]: +for _path in [LOG_PATH, KB_ROOT_PATH, CB_ROOT_PATH, JUPYTER_WORK_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]: if not os.path.exists(_path) and int(os.environ.get("do_create_dir", True)): os.makedirs(_path, exist_ok=True) @@ -83,8 +83,8 @@ # 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 # Mac 可能存在无法使用normalized_L2的问题,因此调整SCORE_THRESHOLD至 0~1100 -FAISS_NORMALIZE_L2 = True if system_name in ["Linux", "Windows"] else False -SCORE_THRESHOLD = 1 if system_name in ["Linux", 
"Windows"] else 1100 +FAISS_NORMALIZE_L2 = True if system_name in ["Darwin", "Linux", "Windows"] else False +SCORE_THRESHOLD = 1 if system_name in ["Darwin", "Linux", "Windows"] else 1100 # 搜索引擎匹配结题数量 SEARCH_ENGINE_TOP_K = os.environ.get("SEARCH_ENGINE_TOP_K") or 5 diff --git a/muagent/base_configs/prompts/functioncall_template_prompt.py b/muagent/base_configs/prompts/functioncall_template_prompt.py new file mode 100644 index 0000000..812e481 --- /dev/null +++ b/muagent/base_configs/prompts/functioncall_template_prompt.py @@ -0,0 +1,36 @@ + +FUNCTION_CALL_PROMPT_en = """You have access to the following functions: + +{tool_desc} + +To call a function, please respond with JSON for a function call. + +Respond in the format [{"name": function name, "arguments": dictionary of argument name and its value}]. +""" + +FC_AUTO_PROMPT_en = """ +The function can be called zero or multiple according to your needs. +""" + + +FC_REQUIRED_PROMPT_en = """ +You must call a function as least. +""" + +FC_PARALLEL_PROMPT_en = """ +The function can be called in parallel. +""" + + +FC_RESPONSE_PROMPT_en = """## Response Ouput +Response the function calls by formatting the in JSON. The format should be: + +```json +[ +{ + "name": function name, + "arguments": dictionary of argument name and its value +} +] +``` +""" \ No newline at end of file diff --git a/muagent/base_configs/prompts/intention_template_prompt.py b/muagent/base_configs/prompts/intention_template_prompt.py index 4e171fd..64d2c2b 100644 --- a/muagent/base_configs/prompts/intention_template_prompt.py +++ b/muagent/base_configs/prompts/intention_template_prompt.py @@ -74,6 +74,7 @@ ## 输出格式 最相关意图对应的数字(第一个意图对应数字1){extra}。 +其他任何内容都是不允许的。 {example} ## 用户询问 """ @@ -141,16 +142,8 @@ def get_intention_prompt( name='整体计划查询', tag='allPlan' ) INTENTION_NEXTSTEP = IntentionInfo( - description='用户询问某个问题或方案中某一个特定步骤。通常会提及“下一步”、“具体操作”等。', - name='某一步任务查询', tag='nextStep' -) -INTENTION_SEVERALSTEPS = IntentionInfo( - description='用户询问某个问题或方案中其中某几个步骤。', - name='某几步任务查询', tag='severalSteps' -) -INTENTION_BACKGROUND = IntentionInfo( - description='用户询问某个问题或方案的背景知识,规则以及流程介绍等。', - name='背景查询', tag='background' + description='用户询问某个问题或方案的特定步骤,通常会提及“下一步”、“具体操作”等。', + name='下一步任务查询', tag='nextStep' ) INTENTION_CHAT = IntentionInfo( description='用户询问的内容与当前的技术问题或解决方案无关,更多是出于兴趣或社交性质的交流。', @@ -179,16 +172,13 @@ def get_intention_prompt( } ) -INTENTIONS_CONSULT_WHICH = (INTENTION_ALLPLAN, INTENTION_NEXTSTEP, INTENTION_SEVERALSTEPS, INTENTION_BACKGROUND, INTENTION_CHAT) +INTENTIONS_CONSULT_WHICH = (INTENTION_ALLPLAN, INTENTION_NEXTSTEP, INTENTION_CHAT) CONSULT_WHICH_PROMPT = get_intention_prompt( intentions=INTENTIONS_CONSULT_WHICH, examples={ '如何组织一次活动?': INTENTION_ALLPLAN, '系统升级的整个流程是怎样的?': INTENTION_ALLPLAN, '为什么我没有收到红包?请告诉我方案': INTENTION_ALLPLAN, - '如果我想学习一门新语言,第一步我需要先做些什么?': INTENTION_NEXTSTEP, - '项目开发中代码开发完成后需要经过哪几步测试才能发布到生产呢?': INTENTION_SEVERALSTEPS, - '请问下狼人杀游戏中猎人的主要职责是什么?': INTENTION_BACKGROUND, '听说你们采用了新工具,能讲讲它的特点吗?': INTENTION_CHAT } ) diff --git a/muagent/connector/memory/hierarchical_memory_manager.py b/muagent/connector/memory/hierarchical_memory_manager.py index 180dc35..7ccac84 100644 --- a/muagent/connector/memory/hierarchical_memory_manager.py +++ b/muagent/connector/memory/hierarchical_memory_manager.py @@ -15,7 +15,8 @@ from muagent.connector.memory_manager import BaseMemoryManager from muagent.llm_models import * from muagent.base_configs.env_config import KB_ROOT_PATH -from muagent.orm import table_init +# from muagent.orm import table_init +from 
muagent.db_handler import table_init from muagent.utils.common_utils import * diff --git a/muagent/connector/memory_manager.py b/muagent/connector/memory_manager.py index e536466..0c7dedf 100644 --- a/muagent/connector/memory_manager.py +++ b/muagent/connector/memory_manager.py @@ -19,7 +19,8 @@ from muagent.llm_models.llm_config import EmbedConfig, LLMConfig from muagent.retrieval.utils import load_embeddings_from_path from muagent.utils.common_utils import * -from muagent.orm import table_init +# from muagent.orm import table_init +from muagent.db_handler import table_init from muagent.base_configs.env_config import KB_ROOT_PATH # from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD # from configs.model_config import embedding_model_dict diff --git a/muagent/connector/schema/general_schema.py b/muagent/connector/schema/general_schema.py index d87c373..350f72c 100644 --- a/muagent/connector/schema/general_schema.py +++ b/muagent/connector/schema/general_schema.py @@ -26,7 +26,7 @@ class ActionStatus(Enum): def __eq__(self, other): if isinstance(other, str): - return self.value.lower() == other.lower() + return self.value.strip().lower() == other.strip().lower() return super().__eq__(other) @@ -198,6 +198,10 @@ def __le__(self, other): @classmethod def ge(self, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']): return enum_value <= other + + @classmethod + def le(self, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']): + return enum_value <= other class Task(BaseModel): diff --git a/muagent/connector/utils.py b/muagent/connector/utils.py index ea2377e..68ff47c 100644 --- a/muagent/connector/utils.py +++ b/muagent/connector/utils.py @@ -1,3 +1,6 @@ +from typing import ( + Dict, +) import re, copy, json from loguru import logger @@ -70,8 +73,9 @@ def parse_section_to_dict(text, section_name): def parse_text_to_dict(text): - # Define a regular expression pattern to capture the key and value - main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)" + """through a regular expression pattern to capture the key and value""" + # main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)" + main_pattern = r'\*\*([^*]+):\*\*\s*(.*?)(?=\*\*([^*]+):\*\*|$)' list_pattern = r'```python\n(.*?)```' plan_pattern = r'(\[\s*.*?\s*\])' @@ -79,7 +83,10 @@ def parse_text_to_dict(text): main_matches = re.findall(main_pattern, text, re.DOTALL) # Convert main matches to a dictionary - parsed_dict = {key.strip(): value.strip() for key, value in main_matches} + parsed_dict = { + v[0].strip(): v[1].strip() + for v in main_matches + } for k, v in parsed_dict.items(): for pattern in [list_pattern, plan_pattern]: @@ -94,12 +101,13 @@ def parse_text_to_dict(text): return parsed_dict -def parse_dict_to_dict(parsed_dict) -> dict: +def parse_dict_to_dict(parsed_dict: Dict) -> Dict: + """through a regular expression pattern to decode ```python/json/java``` into fragment""" code_pattern = r'```python\n(.*?)```' tool_pattern = r'```json\n(.*?)```' java_pattern = r'```java\n(.*?)```' - pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern} + pattern_dict = {"python": code_pattern, "json": tool_pattern, "java": java_pattern} spec_parsed_dict = copy.deepcopy(parsed_dict) for key, pattern in pattern_dict.items(): for k, text in parsed_dict.items(): diff --git a/muagent/db_handler/__init__.py b/muagent/db_handler/__init__.py index aebb417..f475467 100644 --- a/muagent/db_handler/__init__.py +++ b/muagent/db_handler/__init__.py @@ -8,10 
+8,28 @@ from .graph_db_handler import NebulaHandler, NetworkxHandler, AliYunSLSHandler, GeaBaseHandler, GBHandler from .vector_db_handler import LocalFaissHandler, TbaseHandler, ChromaHandler - +from .db import _engine, Base __all__ = [ "GBHandler", "NebulaHandler", "NetworkxHandler", "GeaBaseHandler", "ChromaHandler", "TbaseHandler", "LocalFaissHandler", "AliYunSLSHandler" -] \ No newline at end of file +] + + +def create_tables(): + Base.metadata.create_all(bind=_engine) + +def reset_tables(): + Base.metadata.drop_all(bind=_engine) + create_tables() + + +def check_tables_exist(table_name) -> bool: + table_exist = _engine.dialect.has_table(_engine.connect(), table_name, schema=None) + return table_exist + +def table_init(): + if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")) or \ + (not check_tables_exist ("code_base")): + create_tables() diff --git a/muagent/orm/db.py b/muagent/db_handler/db.py similarity index 100% rename from muagent/orm/db.py rename to muagent/db_handler/db.py diff --git a/muagent/db_handler/graph_db_handler/nebula_handler.py b/muagent/db_handler/graph_db_handler/nebula_handler.py index 9d94180..eba6e95 100644 --- a/muagent/db_handler/graph_db_handler/nebula_handler.py +++ b/muagent/db_handler/graph_db_handler/nebula_handler.py @@ -57,7 +57,6 @@ def __init__(self,gb_config : GBConfig = None): self.connection_pool = ConnectionPool() if gb_config == None: - self.connection_pool.init([('graphd', '9669')], config) self.username = '' or 'root' self.nb_pw = '' or 'nebula' @@ -116,7 +115,7 @@ def execute_cypher(self, cypher: str, space_name: str = '',ignore_log: bool = Fa if ignore_log == False: if resp.is_succeeded(): - #logger.info(f"Successfully executed Cypher query: {cypher}") + # logger.info(f"Successfully executed Cypher query: {cypher}") pass @@ -165,18 +164,20 @@ def execute_cypher_return_status(self, cypher: str, space_name: str = '', format errorMessage=resp.error_msg(), errorCode=resp.error_code(), ) - def add_hosts(self, hostname, port): + while not self.is_host_connected(hostname, port): with self.connection_pool.session_context(self.username, self.nb_pw) as session: cypher = f'ADD HOSTS "{hostname}":{port}' resp = session.execute(cypher) - return resp + print('增加NebulaGraph Storage主机中,等待20秒') + time.sleep(20) + return def close_connection(self): self.connection_pool.close() - def create_space(self, space_name: str, vid_type: str = 'FIXED_STRING(32)', comment: str = ''): + def create_space(self, space_name: str, vid_type: str = 'FIXED_STRING(1024)', comment: str = ''): ''' create space @param space_name: cannot startwith number @@ -277,6 +278,20 @@ def show_edge_type(self): resp = self.execute_cypher(cypher, self.space_name) return resp + def is_host_connected(self, hostname, port): + # 查询系统表以检查主机的连接状态 + with self.connection_pool.session_context(self.username, self.nb_pw) as session: + cypher = 'SHOW HOSTS' + resp = session.execute(cypher) + + resp = resp.as_primitive() + # 假设返回结果中包含一个名为 'host' 的字段 + for i in resp: + if hostname==i['Host'] and port == i['Port'] and i["Status"] =="ONLINE": + return True + + return False + def delete_edge_type(self, edge_type_name: str): cypher = f'DROP EDGE {edge_type_name}' return self.execute_cypher(cypher, self.space_name) diff --git a/muagent/db_handler/vector_db_handler/local_faiss_handler.py b/muagent/db_handler/vector_db_handler/local_faiss_handler.py index f18605c..e43d429 100644 --- a/muagent/db_handler/vector_db_handler/local_faiss_handler.py +++ 
b/muagent/db_handler/vector_db_handler/local_faiss_handler.py @@ -1,11 +1,15 @@ from loguru import logger -from typing import List +from typing import List, Union from functools import lru_cache import os, shutil from langchain.embeddings.base import Embeddings from langchain_community.docstore.document import Document + +from muagent.models import get_model +from muagent.schemas.models import ModelConfig + from muagent.utils.server_utils import torch_gc from muagent.retrieval.base_service import SupportedVSType from muagent.retrieval.faiss_m import FAISS @@ -23,17 +27,20 @@ class LocalFaissHandler: def __init__( self, - embed_config: EmbedConfig, + embed_config: Union[EmbedConfig, ModelConfig], vb_config: VBConfig = None ): self.vb_config = vb_config self.embed_config = embed_config - self.embeddings = load_embeddings_from_path( - self.embed_config.embed_model_path, - self.embed_config.model_device, - self.embed_config.langchain_embeddings - ) + if isinstance(self.embed_config, ModelConfig): + self.embeddings = get_model(self.embed_config) + else: + self.embeddings = load_embeddings_from_path( + self.embed_config.embed_model_path, + self.embed_config.model_device, + self.embed_config.langchain_embeddings + ) # INIT self.search_index: FAISS = None diff --git a/muagent/ekg_project.py b/muagent/ekg_project.py new file mode 100644 index 0000000..d5b3949 --- /dev/null +++ b/muagent/ekg_project.py @@ -0,0 +1,659 @@ +from typing import ( + Union, + Sequence, + Literal, + Mapping, + Optional, + Dict, + List +) +from pydantic import BaseModel +import os +import json +from loguru import logger +import concurrent.futures +import time +import random + +from .llm_models import LLMConfig, EmbedConfig +from .schemas.db import TBConfig, GBConfig +from .schemas.models import ModelConfig +from .schemas import EKGProjectConfig, Message, Memory, AgentConfig, PromptConfig +from .schemas.common import GNode, GEdge +from .db_handler import * +from .agents import FunctioncallAgent + +from .service.utils import decode_biznodes, encode_biznodes + +# from .connector.schema import Memory, Message +from .connector.memory_manager import TbaseMemoryManager +from .service.ekg_construct.ekg_construct_base import EKGConstructService +from .service.ekg_inference import IntentionRouter +from .service.ekg_reasoning.src.graph_search.graph_search_main import main as reasoning + + +class LingSiResponse(BaseModel): + '''lingsi的输出值, 算法的输入值 + The following is an example: + + .. code-block:: python + + from xx import LingSiResponse + ls_resp = LingSiResponse( + observation={'content': '一起来玩谁是卧底'}, + sessionId='default_sessionId', + scene="UNDERCOVER", + ) + + ls_resp = LingSiResponse( + observation={'toolResponse': '我的单词是一种工业品'}, + currentNodeId='剧本杀/谁是卧底/智能交互/开始新一轮的讨论' + sessionId='default_sessionId', + scene="UNDERCOVER", + type='reactExecution' + ) + ''' + sessionId: str + """The session index""" + + currentNodeId: Optional[str] = None + """The last node index, the first is null""" + + type: Optional[str] = None + """The last execute type, the first is null""" + + agentName:Optional[str]=None + """The agent name from last node output, the first is null""" + + scene: Literal["UNDERCOVER", "WEREWOLF" , "NEXA" ] = "NEXA" + """The scene type of this task.""" + + observation: Optional[Union[str,Dict]] # jsonstr + '''last observation from last node + .. 
code-block:: python + observation: Literal["content", "tool_response"] + ''' + + userAnswer: Optional[str]=None + """no use""" + + startRootNodeId: Optional[str] = '' + """The default team root id""" + + intentionData: Optional[Union[List,str] ] = None + """equal query, only once at first""" + + startFromRoot: Literal['True', 'false', 'true', 'False'] = 'True' + """""" + + intentionRule: Optional[Union[List,str]]= ["nlp"] + """no use""" + + +class QuestionContent(BaseModel): + ''' + {'question': '请玩家根据当前情况发言', 'candidate': None } + ''' + question:str + candidate:Optional[str]=None + +class QuestionDescription(BaseModel): + ''' + {'questionType': 'essayQuestion', + 'questionContent': {'question': '请玩家根据当前情况发言','candidate': None }} + ''' + questionType: Literal["essayQuestion", "multipleChoice"] = "essayQuestion" + questionContent: QuestionContent + +class ToolPlanOneStep(BaseModel): + ''' + tool_plan_one_step = {'toolDescription': '请用户回答', + 'currentNodeId': nodeId, + 'memory': None, + 'type': 'userProblem', + 'questionDescription': {'questionType': 'essayQuestion', + 'questionContent': {'question': '请玩家根据当前情况发言', + 'candidate': None }}} + ''' + currentNodeId: Optional[str] = None + """from last node index""" + + toolDescription:str + """the input for functioncalling""" + + currentNodeInfo: Optional[str] = None + """equal agent name""" + + memory: Optional[str] = None + """memory""" + + questionDescription: Optional[QuestionDescription]=None + """反问的过程""" + + type: Optional[Literal["onlyTool", "userProblem", "reactExecution"]] = None + """request type""" + + +class ResToLingsi(BaseModel): + '''lingsi的输入值, 算法的输出值 + The following is an example: + + .. code-block:: python + + from xx import ResToLingsi + resp_to_ls = ResToLingsi( + sessionId = "default_sessionId", + type="onlyTool", + summary=None, + toolPlan=ToolPlan( + toolDescription="agent_李静", + currentNodeId='26921eb05153216c5a1f585f9d318c77%%@@#agent_李静', + currentNodeInfo='agent_李静', + memory="", + questionDescription=None, + type="reactExecution" + userInteraction='开始新一轮的讨论

**主持人:**
各位玩家请注意,现在所有玩家均存活,我们将按照座位顺序进行发言。发言顺序为1号李静、2号张伟、3号人类玩家、4号王鹏。现在,请1号李静开始发言。' + intentionRecognitionSituation=None, + ) + ''' + sessionId: str + """session index from last node output""" + + toolPlan:Optional[List[ToolPlanOneStep]] = None + """""" + + userInteraction:Optional[str]=None + """if userInteraction, yield""" + + summary: Optional[str] = None + """if summary, end, yield""" + + type: Optional[str] = None + """no use""" + + intentionRecognitionSituation: Optional[str]=None + """no use""" + + +def get_ekg_project_config_from_env( + model_configs: Optional[Dict[str, Union[LLMConfig, ModelConfig]]] = None, + embed_configs: Optional[Dict[str, Union[EmbedConfig, ModelConfig]]] = None, + db_configs: Optional[Mapping[str, Union[GBConfig, TBConfig]]] = None, + agent_configs: Optional[Mapping[str, AgentConfig]] = None, + prompt_configs: Optional[Mapping[str,PromptConfig]] = None, +) -> EKGProjectConfig: + """""" + project_configs = { + "model_configs": {}, + "embed_configs": {}, + "db_configs": {}, + "agent_configs": {}, + "prompt_configs": {}, + } + # + db_config_name_to_class = { + "gb_config": GBConfig, + "tb_config": TBConfig, + } + # init model configs + if model_configs: + for k, v in model_configs.items(): + if isinstance(v, LLMConfig) or isinstance(v, ModelConfig): + project_configs["model_configs"][k] = v + else: + try: + project_configs["model_configs"][k] = ModelConfig(**v) + except: + project_configs["model_configs"][k] = LLMConfig(**v) + elif "model_configs".upper() in os.environ: + _model_configs = json.loads(os.environ["model_configs".upper()]) + for k, v in _model_configs.items(): + try: + project_configs["model_configs"][k] = ModelConfig(**v) + except: + project_configs["model_configs"][k] = LLMConfig(**v) + + chat_list = [_type for _type in project_configs["model_configs"].keys() if "chat" in _type] + embedding_list = [_type for _type in project_configs["model_configs"].keys() if "embedding" in _type] + if chat_list: + model_type = random.choice(chat_list) + default_model_config = project_configs["model_configs"][model_type] + project_configs["model_configs"]["default_chat"] = default_model_config + os.environ["DEFAULT_MODEL_TYPE"] = model_type + os.environ["DEFAULT_MODEL_NAME"] = default_model_config.model_name + os.environ["DEFAULT_API_KEY"] = default_model_config.api_key or "" + os.environ["DEFAULT_API_URL"] = default_model_config.api_url or "" + + if embedding_list: + model_type = random.choice(embedding_list) + default_model_config = project_configs["model_configs"][model_type] + project_configs["model_configs"]["default_embed"] = default_model_config + project_configs[k] = v + + # init embedding configs + if embed_configs: + for k, v in embed_configs.items(): + if isinstance(v, EmbedConfig) or isinstance(v, ModelConfig): + project_configs["embed_configs"][k] = v + else: + try: + project_configs["embed_configs"][k] = EmbedConfig(**v) + except: + project_configs["embed_configs"][k] = ModelConfig(**v) + elif "embed_configs".upper() in os.environ: + embed_configs = json.loads(os.environ["embed_configs".upper()]) + for k, v in embed_configs.items(): + if isinstance(v, EmbedConfig) or isinstance(v, ModelConfig): + project_configs["embed_configs"][k] = v + else: + try: + project_configs["embed_configs"][k] = EmbedConfig(**v) + except: + project_configs["embed_configs"][k] = ModelConfig(**v) + + # init db configs + db_configs = db_configs or json.loads(os.environ["DB_CONFIGS"]) + for k in ["tb_config", "gb_config"]: + if db_configs and k not in db_configs: + raise KeyError( + 
f"EKG must have {k}. " + f"please check your env config or input." + ) + else: + project_configs["db_configs"][k] = db_config_name_to_class[k]( + **db_configs[k]) + + # init agent configs + if "AGENT_CONFIGS" in os.environ: + agent_configs = agent_configs or json.loads(os.environ["AGENT_CONFIGS"]) + agent_configs = { + kk: AgentConfig(**vv) + for kk, vv in agent_configs.items() + } + project_configs["agent_configs"] = agent_configs + else: + logger.warning( + f"Cant't init any AGENT_CONFIGS in this env." + ) + + # init prompt configs + if "PROMPT_CONFIGS" in os.environ: + prompt_configs = prompt_configs or json.loads(os.environ["PROMPT_CONFIGS"]) + prompt_configs = { + kk: PromptConfig(**vv) + for kk, vv in prompt_configs.items() + } + project_configs["prompt_configs"] = prompt_configs + else: + logger.warning( + f"Cant't init any AGENT_CONFIGS in this env." + ) + + + return EKGProjectConfig(**project_configs) + + +class EKG: + """Class to represent and manage the EKG project.""" + + def __init__( + self, + tb_config: Optional[TBConfig] = None, + gb_config: Optional[GBConfig] = None, + embed_config: Union[ModelConfig, EmbedConfig] = None, + llm_config: Union[ModelConfig, LLMConfig] = None, + project_config: EKGProjectConfig = None, + agents: List[str] = [], + tools: List[str] = [], + *, + initialize_space = True + ): + + # Initialize various configuration settings for the EKG project. + self.tb_config = tb_config + self.gb_config = gb_config + self.embed_config = embed_config + self.llm_config = llm_config + self.project_config = project_config + self.agents = agents + self.tools = tools + + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + self.futures = [] + # Set whether to initialize space + self.initialize_space = initialize_space + self.init_from_project() + + @classmethod + def from_project(cls, project_config: EKGProjectConfig, initialize_space=False) -> 'EKG': + """Create an instance of EKG from a project configuration.""" + return cls(project_config=project_config, initialize_space=initialize_space) + + def init_from_project(self): + """Initialize settings from the provided project configuration.""" + + # Setup the time-based configuration + if self.project_config and self.project_config.db_configs: + self.tb_config = self.tb_config or \ + self.project_config.db_configs.get("tb_config") + elif self.tb_config: + pass + else: + raise KeyError( + f"EKG Project must have 'tb_config' in " + f"db_configs" + ) + + # Setup the graph-based configuration + if self.project_config and self.project_config.db_configs: + self.gb_config = self.gb_config or \ + self.project_config.db_configs.get("gb_config") + elif self.gb_config: + pass + else: + raise KeyError( + f"EKG Project must have 'gb_config' in " + f"db_configs" + ) + + # Setup embedding configuration + if self.project_config and self.project_config.embed_configs: + if "default" not in self.project_config.embed_configs: + raise KeyError( + f"EKG Project must have key=default in " + f"embed_configs" + ) + self.embed_config = self.project_config.embed_configs.get("default") + + # Setup LLM configuration and environment variables + if self.project_config and self.project_config.model_configs: + # if "default_chat" not in self.project_config.llm_configs: + # raise KeyError( + # f"EKG Project must have key=default in " + # f"llm_configs" + # ) + + # os.environ["API_BASE_URL"] = self.project_config.llm_configs["default"].api_base_url + # os.environ["OPENAI_API_KEY"] = self.project_config.llm_configs["default"].api_key + # 
os.environ["model_name"] = self.project_config.llm_configs["default"].model_name + # os.environ["model_engine"] = self.project_config.llm_configs["default"].model_engine + # os.environ["llm_temperature"] = self.project_config.llm_configs["default"].temperature + self.llm_config = self.project_config.model_configs.get("default_chat") + self.llm_config = LLMConfig( + model_name=os.environ["model_name"], + model_engine=os.environ["model_engine"], + api_key=os.environ["OPENAI_API_KEY"], + api_base_url=os.environ["API_BASE_URL"], + ) + + # Ensure 'codefuser' config exists + if "codefuser" not in self.project_config.model_configs: + raise KeyError( + f"EKG Project must have key=codefuser in " + f"llm_configs" + ) + + os.environ["gpt4-API_BASE_URL"] = self.project_config.model_configs["codefuser"].api_base_url + os.environ["gpt4-OPENAI_API_KEY"] = self.project_config.model_configs["codefuser"].api_key + os.environ["gpt4-model_name"] = self.project_config.model_configs["codefuser"].model_name + os.environ["gpt4-model_engine"] = self.project_config.model_configs["codefuser"].model_engine + os.environ["gpt4-llm_temperature"] = self.project_config.model_configs["codefuser"].temperature + + self._init_ekg_construt_service() # Initialize the EKG construction service + self._init_memory_manager() # Initialize the memory manager + self._init_intention_router() # Initialize the intention router + + def _init_ekg_construt_service(self): + """Initialize the service responsible for building the EKG graph.""" + self.ekg_construct_service = EKGConstructService( + embed_config=self.embed_config, + llm_config=self.llm_config, + tb_config=self.tb_config, + gb_config=self.gb_config, + initialize_space=self.initialize_space + ) + + def _init_memory_manager(self): + """Initialize the memory manager with the appropriate configuration.""" + tb = TbaseHandler( + self.tb_config, + self.tb_config.index_name, + definition_value=self.tb_config.extra_kwargs.get( + "memory_definition_value") + ) + + self.memory_manager = TbaseMemoryManager( + unique_name="EKG", + embed_config=self.embed_config, + llm_config=self.llm_config, + tbase_handler=tb, # Use the Tbase handler for database management + use_vector=False + ) + + def _init_intention_router(self): + """Initialize the routing mechanism for intentions within the EKG project.""" + self.intention_router = IntentionRouter( + self.ekg_construct_service.model, + self.ekg_construct_service.gb, + self.ekg_construct_service.tb, + self.embed_config + ) + + def __call__(self): + """Call method for EKG class instance (to be implemented).""" + pass + + def add_node( + self, + node: Union[Dict, GNode], + *, + teamid: str = "default", + ) -> None: + """Add a node to the EKG graph.""" + gnode = GNode(**node) if isinstance(node, Dict) else node + gnodes, _ = decode_biznodes([gnode]) # Decode the business nodes + self.ekg_construct_service.add_nodes(gnodes, teamid) # Add nodes to the construct service + + def add_edge( + self, + start_id: str, + end_id: str, + *, + teamid: str = "", + ) -> None: + """Add an edge between two nodes in the EKG graph.""" + start_node = self.ekg_construct_service.get_node_by_id(start_id) + end_node = self.ekg_construct_service.get_node_by_id(end_id) + + # If both start and end nodes exist, create an edge + if start_node and end_node: + edge = { + "start_id": start_id, + "end_id": end_id, + "type": f"{start_node.type}_route_{end_node.type}", + "attributes": {} + } + edges = [GEdge(**edge)] # Create an edge object + self.ekg_construct_service.add_edges(edges, 
+
+    def run(
+        self,
+        query: str,
+        scene: str = "NEXA",
+        rootid: str = "ekg_team_default",
+    ):
+        """Run the EKG processing with the provided query and scene."""
+        import uuid
+        sessionId = str(uuid.uuid4()).replace("-", "")  # Generate a unique session ID
+        request = LingSiResponse(
+            observation={"content": query},
+            intentionData=query,
+            startRootNodeId=rootid,
+            sessionId=sessionId,
+            scene=scene,
+        )
+        logger.info(query)
+
+        summary = ""  # Initialize the summary variable
+        history_done = []
+        while True:
+            # Wait for the first completed future object
+            done, not_done = concurrent.futures.wait(
+                self.futures, return_when=concurrent.futures.FIRST_COMPLETED
+            )
+            history_done.extend(done)
+            for future in done:
+                self.futures.remove(future)
+
+            if history_done:
+                future = history_done.pop(0)
+                try:
+                    result = future.result()  # Retrieve the result of the completed task
+
+                    # Assemble the new request with the result data
+                    request = LingSiResponse(
+                        observation={"toolResponse": result.get("toolResponse")},
+                        currentNodeId=result.get("currentNodeId"),
+                        type=result.get("type"),
+                        agentName=result.get("agentName"),
+                        startRootNodeId=rootid,
+                        sessionId=sessionId,
+                        scene=scene,
+                    )
+                except Exception as e:
+                    logger.error(f"Task generated an exception: {e}")
+
+            # Perform inference using the request
+            if request:
+                result = reasoning(
+                    request.dict(),
+                    self.memory_manager,
+                    self.ekg_construct_service.gb,
+                    self.intention_router,
+                    self.llm_config
+                )
+                # Yield user interaction if present
+                if result.get("userInteraction"):
+                    print(result["userInteraction"])
+                    yield result["userInteraction"]
+
+                summary = summary or result.get("summary")  # Update the summary if empty
+
+                # If a summary is available, yield it and break the loop
+                if summary:
+                    print(summary)
+                    yield summary
+                    break
+
+                # Update tasks in the pool based on the result
+                user_tasks = []
+                toolPlans = result.get("toolPlan", []) or []
+                for toolplan in toolPlans:
+                    # If the tool description carries "关键信息" (key information),
+                    # short-circuit with a blank response
+                    if "关键信息" in toolplan["toolDescription"]:
+                        self.futures.append(
+                            self.executor.submit(
+                                self.empty_function,
+                                **{
+                                    "content": " ",
+                                    "toolResponse": " ",
+                                    "type": toolplan["type"],
+                                    "currentNodeId": toolplan["currentNodeId"],
+                                    "agentName": toolplan.get("currentNodeInfo"),
+                                }
+                            )
+                        )
+                        continue
+
+                    if toolplan["type"] == "userProblem":
+                        user_tasks.append(toolplan)
+
+                    if toolplan["type"] in ["onlyTool", "reactExecution"]:
+                        # Submit the function call to the executor for execution
+                        future = self.executor.submit(
+                            self.function_call,
+                            **{
+                                "content": toolplan["toolDescription"],
+                                "type": toolplan["type"],
+                                "currentNodeId": toolplan["currentNodeId"],
+                                "agentName": toolplan.get("currentNodeInfo"),
+                                "memory": toolplan.get("memory"),
+                                "toolDescription": toolplan.get("toolDescription")
+                            }
+                        )
+                        self.futures.append(future)
+
+                # Process user tasks and gather user input
+                for user_task in user_tasks:
+                    questionType = user_task.get(
+                        "questionDescription").get("questionType")
+                    user_query = user_task.get(
+                        "questionDescription").get(
+                            "questionContent").get(
+                                "question")
+                    if user_query is None or questionType is None:
+                        continue
+
+                    print(user_query)
+                    yield user_query  # Yield the user query for input
+                    user_answer = input()  # Get user input
+
+                    # Submit the user answer as a future task
+                    self.futures.append(
+                        self.executor.submit(
+                            self.empty_function,
+                            **{
+                                "content": user_answer,
+                                "toolResponse": user_answer,
+                                "type": user_task["type"],
+                                "currentNodeId": user_task["currentNodeId"],
+                                "agentName": user_task["currentNodeInfo"]
+                            }
+                        )
+                    )
+
+            if not self.futures:  # If there are no tasks, pause briefly before checking again
+                time.sleep(0.1)
+
+            # Reset the request to prepare for the next inference loop
+            request = None
+            toolPlans = None
+            result = None
+
+    def empty_function(self, **kwargs) -> Dict:
+        """Return the input parameters as is (placeholder function)."""
+        return kwargs
+
+    def function_call(self, **kwargs) -> Dict:
+        """Perform a single-step function call and return the result."""
+
+        function_caller = FunctioncallAgent(
+            agent_name="codefuse_function_caller",  # Set the agent name
+            project_config=self.project_config,  # Provide the project configuration
+            tools=self.tools  # Provide the tools available for use
+        )
+
+        query = Message(
+            role_type="user",
+            content=f"帮我选择匹配的工具并进行执行,工具描述为'{kwargs['content']}'"
+        )
+        msg = None
+        for msg in function_caller.step_stream(query, extra_params={"memory": kwargs.get("memory")}):
+            pass  # Consume the stream; only the final message is needed
+
+        observation = ""
+        # Extract the observation from the processed messages
+        if msg and msg.parsed_contents:
+            observation = msg.parsed_contents[-1].get("Observation", "")
+        result = {
+            "toolResponse": observation,
+            "currentNodeId": kwargs.get("currentNodeId"),
+            "type": kwargs.get("type"),
+            "agentName": kwargs.get("agentName"),
+        }
+        return result  # Return the result of the function call
diff --git a/muagent/httpapis/ekg_construct/api.py b/muagent/httpapis/ekg_construct/api.py
index 9dd8f8e..7565bbc 100644
--- a/muagent/httpapis/ekg_construct/api.py
+++ b/muagent/httpapis/ekg_construct/api.py
@@ -1,9 +1,7 @@
 from fastapi import FastAPI
 from typing import Dict
-import asyncio
 import uvicorn
 from loguru import logger
-import tqdm
 import ollama
 import json
 import os
@@ -126,6 +124,34 @@ async def update_llm_params(request: LLMRequest):
         answer=answer
     )
 
+    # ~/functioncall/chat
+    @app.post("/functioncall/chat", response_model=LLMFCResponse)
+    async def fc_chat(request: LLMFCRequest):
+        # prediction logic
+        errorMessage = "ok"
+        successCode = True
+        choices = []
+        try:
+            model_names = [i["name"] for i in ollama.list()["models"]]
+            if llm.model_type == "ollama" and llm.model_name not in model_names:
+                errorMessage = f"{llm.model_name} not in ollama.list {model_names}. " \
" \ + f"please request llm/ollama/pull for downloading the ollama model" + successCode = False + else: + fc_output = llm.fc(request) + choices = fc_output.choices + except Exception as e: + logger.exception(e) + errorMessage = str(e) + successCode = False + + logger.info(f"choices.type: {type(choices)}") + logger.info(f"choices {choices}") + return LLMFCResponse( + successCode=successCode, errorMessage=errorMessage, + choices=choices + ) + # ~/embeddings/params @app.get("/embeddings/params", response_model=EmbeddingsParamsResponse) async def embedding_params(): diff --git a/muagent/llm_models/llm_shemas.py b/muagent/llm_models/llm_shemas.py new file mode 100644 index 0000000..45f54e9 --- /dev/null +++ b/muagent/llm_models/llm_shemas.py @@ -0,0 +1,50 @@ +from pydantic import BaseModel, Field +from typing import List, Dict, Optional, Union +from enum import Enum + + + +class ChatMessage(BaseModel): + role: str + content: str + + +class FunctionCallData(BaseModel): + name: str + arguments: Union[str, dict] + + +class ToolCall(BaseModel): + id: Optional[Union[str, int]] = None + type: str = "function" + function: FunctionCallData + + +class LLMOuputMessage(BaseModel): + content: Optional[str] = None + role: str + tool_calls: List[ToolCall] = [] + + +class Choice(BaseModel): + finish_reason: str + index: int = 0 + message: LLMOuputMessage + + +class UsageData(BaseModel): + completion_tokens: int + prompt_tokens: int + total_token: int + + +class LLMResponse(BaseModel): + choices: List[Choice] + created: int = 0 + id: str + model: str + object: str + usage: Optional[UsageData] = None + + + diff --git a/muagent/llm_models/openai_model.py b/muagent/llm_models/openai_model.py index c56b988..ba4ee14 100644 --- a/muagent/llm_models/openai_model.py +++ b/muagent/llm_models/openai_model.py @@ -1,5 +1,6 @@ import os -from typing import Union, Optional, List +import re +from typing import Union, Optional, List, Dict, Literal from loguru import logger from langchain.callbacks import AsyncIteratorCallbackHandler @@ -8,7 +9,20 @@ from langchain.llms.base import LLM from .llm_config import LLMConfig +from .llm_shemas import * + +try: + import ollama +except: + pass + # from configs.model_config import (llm_model_dict, LLM_MODEL) +def replacePrompt(prompt: str, keys: list[str] = []): + prompt = prompt.replace("{", "{{").replace("}", "}}") + for key in keys: + prompt = prompt.replace(f"{{{{{key}}}}}", f"{{{key}}}") + return prompt + class CustomLLMModel: @@ -32,6 +46,100 @@ def batch(self, prompts: str, stop: Optional[List[str]] = None): return [self(prompt, stop) for prompt in prompts] + def fc( + self, + messages: List[ChatMessage], + tools: List[Union[str, object]] = [], + system_prompt: Optional[str] = None, + tool_choice: Optional[Literal["auto", "required"]] = "auto", + parallel_tool_calls: Optional[bool] = None, + stop: Optional[List[str]] = None, + **kwargs + ) -> LLMResponse: + ''' + ''' + from muagent.base_configs.prompts.functioncall_template_prompt import ( + FUNCTION_CALL_PROMPT_en, + FC_AUTO_PROMPT_en, + FC_REQUIRED_PROMPT_en, + FC_PARALLEL_PROMPT_en, + FC_RESPONSE_PROMPT_en + ) + + use_tools = len(tools) > 0 + + prompts = [] + if use_tools: + prompts.append(FUNCTION_CALL_PROMPT_en) + + if system_prompt: + prompts.append(system_prompt) + + if use_tools and tool_choice =="auto": + prompts.append(FC_AUTO_PROMPT_en) + elif use_tools and tool_choice =="required": + prompts.append(FC_REQUIRED_PROMPT_en) + + if use_tools and parallel_tool_calls: + prompts.append(FC_PARALLEL_PROMPT_en) + + 
prompts.append("you are a helpful assistant to respond user's question:\n## Question Input\n{content}") + + if use_tools: + prompts.append(FC_RESPONSE_PROMPT_en) + + system_prompt = "\n".join(prompts) + # + content = "\n\n".join([f"{i.role}: {i.content}" for i in messages]) + content = "\n\n".join([f"{i.content}" for i in messages]) + if use_tools: + system_prompt = replacePrompt(system_prompt, keys=["content", "tool_desc"]) + prompt = system_prompt.format(content=content, tool_desc="\n".join(tools)) + else: + system_prompt = replacePrompt(system_prompt, keys=["content"]) + prompt = system_prompt.format(content=content) + + llm_output = self.predict(prompt) + + # logger.info(f"prompt: {prompt}") + # logger.info(f"llm_output: {llm_output}") + # parse llm functioncall + if "```json" in llm_output: + match_value = re.search(r'```json\n(.*?)```', llm_output, re.DOTALL) + else: + match_value = llm_output + + try: + function_call_output = json.loads(match_value.group(1).strip()) + except: + function_call_output = eval(match_value.group(1).strip()) + + function_call_output = function_call_output if isinstance(function_call_output, list) \ + else [function_call_output] + # + fc_response = LLMResponse( + choices=[Choice( + finish_reason="tool_calls", + message=LLMOuputMessage( + content=None, + role="assistant", + tool_calls=[ + ToolCall( + function=FunctionCallData( + name=fco["name"], + arguments=fco["arguments"], + ) + ) + for fco in function_call_output + ], + ) + )], + id="", + model="", + object="chat.completion", + usage=None + ) + return fc_response class OpenAILLMModel(CustomLLMModel): @@ -128,6 +236,24 @@ def __init__(self, llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler ) +class OllamaLLMModel(CustomLLMModel): + def __init__(self, llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None,): + self.llm = ollama.Client() + self.model_name = llm_config.model_name + + def __call__(self, prompt: str, + stop: Optional[List[str]] = None): + stream = self.llm.generate( + model=self.model_name, + prompt=prompt, + stream=True, + ) + answer = "" + for chunk in stream: + answer += chunk['response'] + return answer + + class KIMILLMModel(LYWWLLMModel): pass @@ -143,7 +269,7 @@ def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbac model_class_dict = { "openai": OpenAILLMModel, "lingyiwanwu": LYWWLLMModel, "kimi": KIMILLMModel, "moonshot": KIMILLMModel, - "qwen": QwenLLMModel, + "qwen": QwenLLMModel, "ollama": OllamaLLMModel } model_class = model_class_dict[llm_config.model_engine] model = model_class(llm_config, callBack) diff --git a/muagent/memory/__init__.py b/muagent/memory/__init__.py deleted file mode 100644 index 719b23b..0000000 --- a/muagent/memory/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .hierarchical_memory_manager import HierarchicalMemoryManager - -__all__ = [ - "HierarchicalMemoryManager" -] \ No newline at end of file diff --git a/muagent/memory_manager/__init__.py b/muagent/memory_manager/__init__.py new file mode 100644 index 0000000..59215ce --- /dev/null +++ b/muagent/memory_manager/__init__.py @@ -0,0 +1,13 @@ +from .hierarchical_memory_manager import HierarchicalMemoryManager +from .base_memory_manager import BaseMemoryManager +from .local_memory_manager import LocalMemoryManager +from .tbase_memory_manager import TbaseMemoryManager + + + +__all__ = [ + "BaseMemoryManager", + "LocalMemoryManager", + "TbaseMemoryManager", + "HierarchicalMemoryManager" +] \ No newline at end of file diff --git 
new file mode 100644
index 0000000..e2ad8e2
--- /dev/null
+++ b/muagent/memory_manager/base_memory_manager.py
@@ -0,0 +1,271 @@
+from abc import abstractmethod, ABC
+from typing import (
+    List,
+    Dict,
+    Optional
+)
+from loguru import logger
+
+from ..schemas import Memory, Message
+from ..schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from ..schemas.models import ModelConfig
+from ..db_handler import *
+
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
+
+
+class BaseMemoryManager(ABC):
+    """
+    Abstract base class for memory managers.
+
+    Attributes:
+    - memory_manager_type: A string identifying the memory manager type.
+    - do_init: A boolean indicating whether to initialize. Default is False.
+    - recall_memory_dict: A mapping from session index to its recall Memory.
+    - save_message_keys: A list of strings representing the keys for saving messages.
+
+    Methods:
+    - __init__: Initializes the memory manager with the given database configs and do_init flag.
+    - init_vb: Initializes the vb.
+    - append: Appends a message to the recall memory, current memory, and summary memory.
+    - extend: Extends the recall memory, current memory, and summary memory.
+    - load: Loads the memory from the specified directory and returns a Memory instance.
+    - router_retrieval: Routes the retrieval based on the retrieval type.
+    - embedding_retrieval: Retrieves messages based on embedding.
+    - text_retrieval: Retrieves messages based on text.
+    - datetime_retrieval: Retrieves messages based on datetime.
+    - recursive_summary: Performs recursive summarization of messages.
+    """
+
+    memory_manager_type: str = "base_memory_manager"
+    """The type of memory manager"""
+
+    def __init__(
+        self,
+        vb_config: Optional[VBConfig] = None,
+        db_config: Optional[DBConfig] = None,
+        gb_config: Optional[GBConfig] = None,
+        tb_config: Optional[TBConfig] = None,
+        embed_config: Optional[ModelConfig] = None,
+        do_init: bool = False,
+    ):
+        """
+        Initializes the memory manager with the given parameters.
+
+        Args:
+        - vb_config: VBConfig, the vector base config
+        - db_config: DBConfig, the database config
+        - gb_config: GBConfig, the graph base config
+        - tb_config: TBConfig, the Tbase config
+        - embed_config: ModelConfig, the embedding model config
+        - do_init: A boolean indicating whether to initialize. Default is False.
+        """
+        self.db_config = db_config
+        self.vb_config = vb_config
+        self.gb_config = gb_config
+        self.tb_config = tb_config
+        self.embed_config = embed_config
+        self.do_init = do_init
+        self.recall_memory_dict: Dict[str, Memory] = {}
+        self.save_message_keys = [
+            'session_index', 'role_name', 'role_type', 'role_prompt', 'input_query',
+            'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
+            'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
+
+    def init_handler(self, ):
+        """Initializes the database, vector base, graph base and Tbase handlers"""
+        self.init_vb()
+        self.init_tb()
+        self.init_db()
+        self.init_gb()
+
+    def reinit_handler(self, do_init: bool=False):
+        self.init_vb()
+        self.init_tb()
+        self.init_db()
+        self.init_gb()
+
+    def init_vb(self, do_init: bool=None):
+        """
+        Initializes the vector base handler.
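+        Falls back to ``LocalFaissHandler`` when ``vb_config.vb_type`` does
+        not name a known handler.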
+ """ + if self.vb_config: + table_init() + vb_dict = {"LocalFaissHandler": LocalFaissHandler} + vb_class = vb_dict.get(self.vb_config.vb_type, LocalFaissHandler) + self.vb: LocalFaissHandler = vb_class(self.embed_config, vb_config=self.vb_config) + + def init_db(self, ): + """Initializes Database VectorBase GraphDB TbaseDB""" + if self.db_config: + db_dict = {"LocalFaissHandler": LocalFaissHandler} + db_class = db_dict.get(self.db_config.db_type) + self.db = db_class(self.db_config) + + def init_tb(self, do_init: bool=None): + """ + Initializes the tb. + """ + if self.tb_config: + tb_dict = {"TbaseHandler": TbaseHandler} + tb_class = tb_dict.get(self.tb_config.tb_type, TbaseHandler) + self.tb = tb_class(self.tb_config, self.tb_config.index_name) + + def init_gb(self, do_init: bool=None): + """ + Initializes the gb. + """ + if self.gb_config: + gb_dict = {"NebulaHandler": NebulaHandler} + gb_class = gb_dict.get(self.gb_config.gb_type, NebulaHandler) + self.gb = gb_class(self.gb_config) + + def append(self, message: Message, role_tag: str): + """ + Appends a message to the recall memory, current memory, and summary memory. + + Args: + - message: An instance of Message class representing the message to be appended. + """ + pass + + def extend(self, memory: Memory, role_tag: str): + """ + Extends the recall memory, current memory, and summary memory. + + Args: + - memory: An instance of Memory class representing the memory to be extended. + """ + pass + + def load(self, load_dir: str = "") -> Memory: + """ + Loads the memory from the specified directory and returns a Memory instance. + + Args: + - load_dir: A string representing the directory to load the memory from. Default is KB_ROOT_PATH. + + Returns: + - An instance of Memory class representing the loaded memory. + """ + pass + + def get_memory_pool(self, session_index: str) -> Memory: + """ + return memory_pool + """ + pass + + def search_messages(self, text: str=None, n=5, **kwargs) -> List[Message]: + """ + return the search messages + + Args: + - text: A string representing the text for retrieval. Default is None. + - n: An integer representing the number of messages. Default is 5. + """ + + def router_retrieval(self, + session_index: str = "default", text: str=None, datetime: str = None, + n=5, top_k=5, retrieval_type: str = "embedding", **kwargs + ) -> Memory: + """ + Routes the retrieval based on the retrieval type. + + Args: + - text: A string representing the text for retrieval. Default is None. + - datetime: A string representing the datetime for retrieval. Default is None. + - n: An integer representing the number of messages. Default is 5. + - top_k: An integer representing the top k messages. Default is 5. + - retrieval_type: A string representing the retrieval type. Default is "embedding". + - **kwargs: Additional keyword arguments for retrieval. + + Returns: + - A list of Message instances representing the retrieved messages. + """ + retrieval_func_dict = { + "embedding": self.embedding_retrieval, + "text": self.text_retrieval, + "datetime": self.datetime_retrieval + } + + # 确保提供了合法的检索类型 + if retrieval_type not in retrieval_func_dict: + raise ValueError( + f"Invalid retrieval_type: '{retrieval_type}'. 
" + f"Available types: {list(retrieval_func_dict.keys())}" + ) + + retrieval_func = retrieval_func_dict[retrieval_type] + # + params = locals() + params.pop("self") + params.pop("retrieval_type") + params.update(params.pop('kwargs', {})) + # + return retrieval_func(**params) + + def embedding_retrieval(self, text: str, embed_model="", top_k=1, score_threshold=1.0, **kwargs) -> Memory: + """ + Retrieves messages based on embedding. + + Args: + - text: A string representing the text for retrieval. + - embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL. + - top_k: An integer representing the top k messages. Default is 1. + - score_threshold: A float representing the score threshold. Default is SCORE_THRESHOLD. + - **kwargs: Additional keyword arguments for retrieval. + + Returns: + - A list of Message instances representing the retrieved messages. + """ + pass + + def text_retrieval(self, text: str, **kwargs) -> Memory: + """ + Retrieves messages based on text. + + Args: + - text: A string representing the text for retrieval. + - **kwargs: Additional keyword arguments for retrieval. + + Returns: + - A list of Message instances representing the retrieved messages. + """ + pass + + def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> Memory: + """ + Retrieves messages based on datetime. + + Args: + - datetime: A string representing the datetime for retrieval. + - text: A string representing the text for retrieval. Default is None. + - n: An integer representing the number of messages. Default is 5. + - **kwargs: Additional keyword arguments for retrieval. + + Returns: + - A list of Message instances representing the retrieved messages. + """ + pass + + def recursive_summary(self, messages: List[Message], split_n: int = 20) -> Memory: + """ + Performs recursive summarization of messages. + + Args: + - messages: A list of Message instances representing the messages to be summarized. + - split_n: An integer representing the split n. Default is 20. + + Returns: + - A list of Message instances representing the summarized messages. 
+ """ + pass + + def reranker(self, ): + """ + rerank the retrieval message from memory + """ + pass + diff --git a/muagent/memory/hierarchical_memory_manager.py b/muagent/memory_manager/hierarchical_memory_manager.py similarity index 96% rename from muagent/memory/hierarchical_memory_manager.py rename to muagent/memory_manager/hierarchical_memory_manager.py index 180dc35..99b40f8 100644 --- a/muagent/memory/hierarchical_memory_manager.py +++ b/muagent/memory_manager/hierarchical_memory_manager.py @@ -15,8 +15,8 @@ from muagent.connector.memory_manager import BaseMemoryManager from muagent.llm_models import * from muagent.base_configs.env_config import KB_ROOT_PATH -from muagent.orm import table_init - +# from muagent.orm import table_init +from muagent.db_handler import table_init from muagent.utils.common_utils import * diff --git a/muagent/memory_manager/local_memory_manager.py b/muagent/memory_manager/local_memory_manager.py new file mode 100644 index 0000000..bcb3180 --- /dev/null +++ b/muagent/memory_manager/local_memory_manager.py @@ -0,0 +1,443 @@ +from abc import abstractmethod, ABC +from typing import List, Dict +import os, sys, copy, json, uuid, random +from jieba.analyse import extract_tags +from collections import Counter +from loguru import logger +import numpy as np + +from langchain_community.docstore.document import Document + + +from .base_memory_manager import BaseMemoryManager + +from ..schemas import Memory, Message +from ..schemas.models import ModelConfig +from ..schemas.db import DBConfig, GBConfig, VBConfig, TBConfig + +from ..models import get_model + +from muagent.connector.configs.generate_prompt import * +from muagent.db_handler import * +from muagent.llm_models import getChatModelFromConfig +from muagent.llm_models.llm_config import EmbedConfig, LLMConfig +from muagent.utils.common_utils import * +from muagent.base_configs.env_config import KB_ROOT_PATH + + +class LocalMemoryManager(BaseMemoryManager): + """This class represents a LocalMemoryManager that inherits from BaseMemoryManager. + It provides functionalities to handle local memory storage and retrieval of messages. + """ + memory_manager_type: str = "local_memory_manager" + """The type of memory manager""" + + def __init__( + self, + embed_config: Union[ModelConfig, EmbedConfig], + llm_config: Union[LLMConfig, ModelConfig], + vb_config: Optional[VBConfig] = None, + db_config: Optional[DBConfig] = None, + gb_config: Optional[GBConfig] = None, + tb_config: Optional[TBConfig] = None, + do_init: bool = False, + kb_root_path: str = KB_ROOT_PATH, + ): + """Initialize the LocalMemoryManager with configurations. + + Args: + embed_config (Union[ModelConfig, EmbedConfig]): Configuration for embedding. + llm_config (Union[LLMConfig, ModelConfig]): Configuration for LLM. + vb_config (Optional[VBConfig], optional): Vector database configuration. + db_config (Optional[DBConfig], optional): Database configuration. + gb_config (Optional[GBConfig], optional): Graph database configuration. + tb_config (Optional[TBConfig], optional): Tbase configuration. + do_init (bool, optional): Flag indicating if initialization is required. + kb_root_path (str, optional): Path for storing knowledge base files (default is KB_ROOT_PATH). 
+ """ + super().__init__( + vb_config or VBConfig(vb_type="LocalFaissHandler"), + db_config, gb_config, tb_config, + embed_config + ) + + self.do_init = do_init + self.kb_root_path = kb_root_path + self.embed_config: Union[ModelConfig, EmbedConfig] = embed_config + self.llm_config: Union[LLMConfig, ModelConfig] = llm_config + + # default + self.session_index: str = "default" + self.kb_name = f"{self.session_index}" + self.uuid_file = os.path.join( + self.kb_root_path, f"{self.session_index}/conversation.jsonl") + + self.recall_memory_dict: Dict[str, Memory] = {} + self.memory_uuids = set() + self.save_message_keys = [ + 'session_index', 'message_index', 'role_name', 'role_type', 'content', + 'input_text', 'role_tags', 'content', 'step_content', + 'parsed_content', 'spec_parsed_contents', 'global_kwargs', + 'start_datetime', 'end_datetime', + "keyword", "vector" + ] + # init from config + if isinstance(self.llm_config, LLMConfig): + self.model = getChatModelFromConfig(self.llm_config) + else: + self.model = get_model(self.llm_config) + self.init_handler() + self.load(do_init) + + def clear_local(self, re_init: bool = False, handler_type: str = None): + """Clear local memory and reinitialize if specified. + + Args: + re_init (bool, optional): Whether to reinitialize after clearing. + handler_type (str, optional): Type of handler to use (currently unused). + """ + if self.vb: # 存到了本地需要清理 + self.vb.clear_vs_local() + self.load(re_init) + + def append(self, message: Message, role_tag: str=None) -> None: + """Append a message to the local memory and update vector store if necessary. + + Args: + message (Message): The message to append. + role_tag (str, optional): An optional role tag for the message. + """ + # update the newest uuid_name + self.check_uuid_name(message) + datetimes = self.recall_memory_dict[self.session_index].get_datetimes() + contents = self.recall_memory_dict[self.session_index].get_contents() + message_indexes = self.recall_memory_dict[ + self.session_index].get_memory_values("message_index") + # if message not in chat history, no need to update + if message.message_index in message_indexes: + self.update2vb(message, role_tag) + elif ((message.end_datetime not in datetimes) or + ((message.input_text not in contents) and (message.content not in contents)) + ): + self.append2vb(message, role_tag) + + def append2vb(self, message: Message, role_tag: str=None) -> None: + """Append a message and its embeddings to the vector store (VB). + + Args: + message (Message): The message to append to the vector store. + role_tag (str, optional): Optional role tag for the message. + """ + if role_tag: + if isinstance(message.role_tags, list): + message.role_tags = list(set(message.role_tags + [role_tag])) + else: + message.role_tags += f", {role_tag}" + self.recall_memory_dict[self.session_index].append(message) + memory = self.recall_memory_dict[self.session_index] + # + docs, json_messages = self.message_process([message]) + if self.embed_config: + self.vb.add_docs(docs, kb_name=self.kb_name) + # + if True: # resave the local + _, json_messages = self.message_process(memory.messages) + save_to_json_file(json_messages, self.uuid_file) + + + def update2vb(self, message: Message, role_tag: str=None) -> None: + """Update an existing message in the vector store. + + Args: + message (Message): The message to update. + role_tag (str, optional): Optional role tag for the message. 
+ """ + memory = self.recall_memory_dict[self.session_index] + memory.update(message, role_tag) + + # + docs, json_messages = self.message_process([message]) + # if self.embed_config: + # # search + # # delete + # # add + # self.vb.add_docs(docs, kb_name=self.kb_name) + # + if True: # resave the local + _, json_messages = self.message_process(memory.messages) + save_to_json_file(json_messages, self.uuid_file) + + + def extend(self, memory: Memory, role_tag: str=None): + """Append multiple messages from a Memory object to local memory. + + Args: + memory (Memory): The Memory object containing messages to append. + role_tag (str, optional): An optional role tag for messages. + """ + for message in memory.messages: + self.append(message, role_tag) + + def message_process(self, messages: List[Message]): + """Convert message objects to vector store/local data format. + + Args: + messages (List[Message]): List of messages to process. + + Returns: + Tuple[List[Document], dict]: Tuple containing documents for vector storage and a JSON representation of messages. + """ + messages = [{ + k: v for k, v in m.dict().items() + if k in self.save_message_keys + } + for m in messages + ] + docs = [{ + "page_content": m["step_content"] or m["content"] or m["input_text"], + "metadata": m} + for m in messages + ] + docs = [Document(**doc) for doc in docs] + # convert messages to local data-format + memory_messages = self.recall_memory_dict[self.session_index].dict() + json_messages = { + k: [ + {kkk: vvv for kkk, vvv in vv.items() + if kkk in self.save_message_keys} + for vv in v + ] + for k, v in memory_messages.items() + } + + return docs, json_messages + + def load(self, re_init=False) -> Memory: + """Load memory from files in the specified database root path. + + Args: + re_init (bool, optional): Flag indicating if reinitialization of memory should occur. + + Returns: + Memory: Loaded messages in memory format. + """ + if not re_init: + for root, dirs, files in os.walk(self.kb_root_path): + for file in files: + if file != 'conversation.jsonl': continue + file_path = os.path.join(root, file) + # get uuid_name + relative_path = os.path.relpath(root, self.kb_root_path) + path_parts = relative_path.split(os.sep) + uuid_name = "_".join(path_parts) + # load to local cache + recall_memory = Memory(**read_json_file(file_path)) + self.recall_memory_dict[uuid_name] = recall_memory + else: + self.recall_memory_dict = {} + + def get_memory_pool(self, session_index: str = "") -> Memory: + """Retrieve the memory pool for a specific session index. + + Args: + session_index (str, optional): Session index (default is empty string). + + Returns: + Memory: Retrieved messages in memory format. + """ + return self.recall_memory_dict.get(session_index, Memory(messages=[])) + + def embedding_retrieval( + self, + text: str, + top_k=1, + score_threshold=0.7, + session_index: str = "default", + **kwargs + ) -> List[Message]: + """Retrieve messages based on text embedding. + + Args: + text (str): The input text for embedding retrieval. + top_k (int, optional): The number of top results to retrieve (default is 1). + score_threshold (float, optional): Minimum score for message retrieval (default is 0.7). + session_index (str, optional): Session identifier (default is "default"). + + Returns: + Memory: Retrieved messages in memory format. 
+ """ + if text is None: return Memory(messages=[]) + + # kb_name = self.get_vbname_from_sessionindex(session_index) + kb_name = session_index + docs = self.vb.search( + text, + top_k=top_k, + score_threshold=score_threshold, + kb_name=kb_name + ) + return Memory(messages=[Message(**doc.metadata) for doc, score in docs]) + + def text_retrieval( + self, + text: str, + session_index: str = "default", + **kwargs + ) -> Memory: + """Retrieve messages based on text content. + + Args: + text (str): The text to match against messages. + session_index (str, optional): Session identifier (default is "default"). + + Returns: + Memory: Messages matching the text content. + """ + if text is None: return Memory(messages=[]) + + # uuid_name = self.get_uuid_from_sessionindex(session_index) + messages = self.recall_memory_dict.get( + session_index, Memory(messages=[])).messages + return self._text_retrieval_from_cache( + messages, text, score_threshold=0.3, topK=5, **kwargs + ) + + def datetime_retrieval( + self, + session_index: str, + datetime: str, + text: str = None, + n: int = 5, + key: str = "start_datetime", + **kwargs + ) -> Memory: + """Retrieve messages based on date and time criteria. + + Args: + session_index (str): The session index to filter messages. + datetime (str): The datetime string reference for filtering. + text (str, optional): Optional text to match with messages. + n (int, optional): Number of minutes to define the range (default is 5). + key (str, optional): The key for datetime filtering (default is "start_datetime"). + + Returns: + Memory: Retrieved messages in memory format. + """ + if datetime is None: return Memory(messages=[]) + + # uuid_name = self.get_uuid_from_sessionindex(session_index) + messages = self.recall_memory_dict.get( + session_index, Memory(messages=[])).messages + return self._datetime_retrieval_from_cache( + messages, datetime, text, n, **kwargs + ) + + def _text_retrieval_from_cache( + self, + messages: List[Message], + text: str = None, + score_threshold=0.3, + topK=5, + tag_topK=5, + **kwargs + ) -> Memory: + keywords = extract_tags(text, topK=tag_topK) + + matched_messages = [] + for message in messages: + content = message.step_content or message.input_text or message.content + message_keywords = extract_tags(content, topK=tag_topK) + # calculate jaccard similarity + intersection = Counter(keywords) & Counter(message_keywords) + union = Counter(keywords) | Counter(message_keywords) + similarity = sum(intersection.values()) / sum(union.values()) + if similarity >= score_threshold: + matched_messages.append((message, similarity)) + matched_messages = sorted(matched_messages, key=lambda x:x[1]) + # return [m for m, s in matched_messages][:topK] + return Memory(messages=[m for m, s in matched_messages][:topK] ) + + def _datetime_retrieval_from_cache( + self, + messages: List[Message], + datetime: str, + text: str = None, + n: int = 5, **kwargs + ) -> Memory: + # select message by datetime + datetime_before, datetime_after = addMinutesToTimestamp(datetime, n) + select_messages = [ + message for message in messages + if datetime_before<=dateformatToTimestamp(message.end_datetime, 1, message.datetime_format)<=datetime_after + ] + return self._text_retrieval_from_cache(select_messages, text) + + def recursive_summary( + self, + messages: List[Message], + split_n: int = 20, + session_index: str="" + ) -> Memory: + """Generate a recursive summary of the provided messages. + + Args: + messages (List[Message]): List of messages to summarize. 
+            split_n (int, optional): Number of newest messages kept verbatim (default is 20).
+            session_index (str, optional): Session identifier for the summary.
+
+        Returns:
+            Memory: Updated messages including the summary.
+        """
+        if len(messages) == 0:
+            return Memory(messages=[])
+
+        newest_messages = messages[-split_n:]
+        summary_messages = messages[:max(0, len(messages)-split_n)]
+
+        while (len(newest_messages) != 0) and (newest_messages[0].role_type != "user"):
+            message = newest_messages.pop(0)
+            summary_messages.append(message)
+
+        # summary
+        summary_content = '\n\n'.join([
+            m.role_type + "\n" + "\n".join(([f"*{k}* {v}" for parsed_output in m.spec_parsed_contents for k, v in parsed_output.items() if k not in ['Action Status']]))
+            for m in summary_messages if m.role_type not in ["summary"]
+        ])
+
+        summary_prompt = createSummaryPrompt(conversation=summary_content)
+        content = self.model.predict(summary_prompt)
+        summary_message = Message(
+            session_index=session_index,
+            role_name="summaryer",
+            role_type="summary",
+            content=content,
+            step_content=content,
+            spec_parsed_contents=[],
+            global_kwargs={}
+        )
+        summary_message.spec_parsed_contents.append({"summary": content})
+        newest_messages.insert(0, summary_message)
+
+        return Memory(messages=newest_messages)
+
+    def check_uuid_name(self, message: Message = None):
+        if message.session_index != self.session_index:
+            self.session_index = message.session_index
+
+            self.kb_name = self.session_index
+            self.uuid_file = os.path.join(self.kb_root_path, f"{self.session_index}/conversation.jsonl")
+
+        self.memory_uuids.add(self.session_index)
+        if self.session_index not in self.recall_memory_dict:
+            self.recall_memory_dict[self.session_index] = Memory(messages=[])
+
+    def modified_message(self, message: Message, update_rule_text: str) -> Message:
+        # build a prompt that combines the update rule text with the current message content
+        prompt = f"结合以下更新内容修改当前消息内容:\n更新内容: {update_rule_text}\n\n当前消息内容:\n{message.content}\n\n请生成新的消息内容:"
+
+        new_content = self.model.predict(prompt)
+
+        message.content = new_content
+
+        return message
\ No newline at end of file
diff --git a/muagent/memory_manager/tbase_memory_manager.py b/muagent/memory_manager/tbase_memory_manager.py
new file mode 100644
index 0000000..49747d3
--- /dev/null
+++ b/muagent/memory_manager/tbase_memory_manager.py
@@ -0,0 +1,628 @@
+from typing import (
+    List,
+    Dict,
+    Union,
+    Optional,
+)
+import numpy as np
+from jieba.analyse import extract_tags
+import random
+from collections import Counter
+from loguru import logger
+import json
+
+from .base_memory_manager import BaseMemoryManager
+
+from ..schemas import Memory, Message
+from ..schemas.models import ModelConfig
+from ..schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+
+from ..db_handler import *
+from ..models import get_model
+
+
+from muagent.llm_models import getChatModelFromConfig
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from muagent.connector.configs.generate_prompt import *
+from muagent.utils.common_utils import *
+
+from muagent.llm_models.get_embedding import get_embedding
+from redis.commands.search.field import (
+    TextField,
+    NumericField,
+    VectorField,
+    TagField
+)
+
+DIM = 768
+MESSAGE_SCHEMA = [
+    TextField("session_index"),
+    TextField("message_index"),
+    TextField("node_index"),
+    TextField("role_name"),
+    TextField("role_type"),
+    TextField('input_text'),
+    TextField("content"),
+    TextField("role_tags"),
+    TextField("parsed_output"),
+    TextField("global_kwargs"),
NumericField("start_datetime",) , + NumericField("end_datetime",), + VectorField("vector", + 'FLAT', + { + "TYPE": "FLOAT32", + "DIM": DIM, + "DISTANCE_METRIC": "COSINE" + }), + TagField(name='keyword', separator='|') +] + + + +class TbaseMemoryManager(BaseMemoryManager): + """ + This class represents a TbaseMemoryManager that inherits from BaseMemoryManager. + """ + + memory_manager_typy = "tbase_memory_manager" + """The type of memory manager for identification purposes.""" + + def __init__( + self, + embed_config: Union[ModelConfig, EmbedConfig], + llm_config: Union[LLMConfig, ModelConfig], + tbase_handler: TbaseHandler = None, + use_vector: bool = False, + vb_config: Optional[VBConfig] = None, + db_config: Optional[DBConfig] = None, + gb_config: Optional[GBConfig] = None, + tb_config: Optional[TBConfig] = None, + do_init: bool = False, + ): + """Initialize the TbaseMemoryManager with specified configurations. + + Args: + embed_config (Union[ModelConfig, EmbedConfig]): Configuration for embedding. + llm_config (Union[LLMConfig, ModelConfig]): Configuration for the LLM. + tbase_handler (TbaseHandler, optional): Handler for Tbase database access. + use_vector (bool, optional): Flag to specify whether to use vector embeddings. + vb_config (Optional[VBConfig], optional): Configuration for the vector database. + db_config (Optional[DBConfig], optional): Configuration for the main database. + gb_config (Optional[GBConfig], optional): Configuration for graph database. + tb_config (Optional[TBConfig], optional): Configuration for Tbase. + do_init (bool, optional): Flag to indicate if initialization is required. + """ + + super().__init__(vb_config, db_config, gb_config, tb_config) + self.do_init = do_init + self.embed_config: Union[ModelConfig, EmbedConfig] = embed_config + self.llm_config: Union[LLMConfig, ModelConfig] = llm_config + self.tb: TbaseHandler = tbase_handler + self.save_message_keys = [ + 'session_index', 'message_index', 'node_index', 'role_name', 'role_type', 'content', + 'input_text', 'role_tags', 'content', 'step_content', + 'parsed_content', 'spec_parsed_contents', 'global_kwargs', + 'start_datetime', 'end_datetime', + "keyword", "vector" + ] + self.use_vector = use_vector + self.init_handler() + self.init_tb_index() + + def init_tb_index(self, do_init: bool=None): + """Initialize the Tbase index if it does not already exist. + + Args: + do_init (bool, optional): Optional flag for initialization (unused here). + """ + # Create index if it does not exist + if not self.tb.is_index_exists(): + res = self.tb.create_index(schema=MESSAGE_SCHEMA) + logger.info(res) + + def append(self, message: Message, role_tag: str=None) -> None: + """Append a message to the Tbase memory. + + Args: + message (Message): The message to be appended. + role_tag (str, optional): Optional role tag for the message. + """ + tbase_message = self.localMessage2TbaseMessage(message, role_tag) # Convert local message to Tbase format + self.tb.insert_data_hash(tbase_message) # Insert into Tbase + + def extend(self, memory: Memory, role_tag: str=None) -> None: + """Append multiple messages from memory to Tbase. + + Args: + memory (Memory): The memory containing messages to append. + role_tag (str, optional): Optional role tag for all messages. + """ + for message in memory.messages: + self.append(message, role_tag) # Append each message + + def append_tools(self, tool_information: dict, session_index: str, nodeid: str, node_index: str="default") -> None: + """Append tool-related information to Tbase as messages. 
+
+        Args:
+            tool_information (dict): Dictionary containing tool information.
+            session_index (str): Session identifier.
+            nodeid (str): Graph node ID.
+            node_index (str, optional): Node index for differentiating nodes.
+        """
+        tool_map = {
+            "toolKey": {"role_name": "tool_selector", "role_type": "assistant",
+                        "customed_keys": ["toolDef"]
+                        },
+            "toolParam": {"role_name": "tool_filler", "role_type": "assistant"},
+            "toolResponse": {"role_name": "function_caller", "role_type": "observation"},
+            "toolSummary": {"role_name": "function_summary", "role_type": "Summary"},
+        }
+
+        for k, v in tool_map.items():
+            # skip tool fields that were not provided
+            if k not in tool_information:
+                continue
+            message = Message(
+                session_index=session_index,
+                message_index=f"{nodeid}_{k}",
+                node_index=node_index,
+                role_name=v["role_name"],  # Assign the role name
+                role_type=v["role_type"],  # Assign the role type
+                content=tool_information[k],  # Assign the tool information content
+                global_kwargs={
+                    kk: vv for kk, vv in tool_information.items()
+                    if kk in v.get("customed_keys", [])
+                }  # Store additional tool information
+            )
+            self.append(message)  # Append the message to Tbase
+
+    def get_memory_by_sessionindex_tags(self, session_index: str, tags: List[str], limit: int = 10) -> Memory:
+        """Retrieve memory messages by session index and tags.
+
+        Args:
+            session_index (str): The session index to search for.
+            tags (List[str]): List of tags to match against messages.
+            limit (int, optional): The maximum number of messages to retrieve (default is 10).
+
+        Returns:
+            Memory: Retrieved messages in memory format.
+        """
+        tags_str = '|'.join([f"*{tag}*" for tag in tags])  # Create a tags search string
+        querys = [
+            f"@session_index:{session_index}",  # Query for the session index
+            f"@role_tags:{tags_str}",  # Query for the role tags
+        ]
+        query = f"({')('.join(querys)})" if len(querys) >= 2 else "".join(querys)  # Combine the queries
+        r = self.tb.search(query, limit=limit)  # Search Tbase
+        return self.tbasedoc2Memory(r)  # Convert the results to Memory format
+
+    def get_memory_by_chatindex_tags(self, chat_index: str, tags: List[str], limit: int = 10) -> Memory:
+        """Retrieve memory messages by chat index and tags.
+
+        Args:
+            chat_index (str): The chat index to search for.
+            tags (List[str]): List of tags to match against messages.
+            limit (int, optional): The maximum number of messages to retrieve (default is 10).
+
+        Returns:
+            Memory: Retrieved messages in memory format.
+        """
+        tags_str = '|'.join([f"*{tag}*" for tag in tags])  # Create a tags search string
+        querys = [
+            f"@session_index:{chat_index}",  # Query for the session index
+            f"@role_tags:{tags_str}",  # Query for the role tags
+        ]
+        query = f"({')('.join(querys)})" if len(querys) >= 2 else "".join(querys)  # Combine the queries
+        logger.debug(f"{query}")
+        r = self.tb.search(query, limit=limit)  # Search Tbase
+        return self.tbasedoc2Memory(r)  # Convert the results to Memory format
+
+    def get_memory_pool(self, session_index: str = "") -> Memory:
+        """Get the memory pool for a specific session index.
+
+        Args:
+            session_index (str, optional): Session index (default is empty string).
+
+        Returns:
+            Memory: Retrieved messages in memory format.
+        """
+        return self.get_memory_pool_by_all({"session_index": session_index})  # Retrieve all memory for the session
+
+    def get_memory_pool_by_content(self, content: str) -> Memory:
+        """Get the memory pool based on a content search.
+
+        Args:
+            content (str): Content to search for in messages.
+
+        Returns:
+            Memory: Retrieved messages in memory format.
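+
+        Example (illustrative query)::
+
+            memory = mm.get_memory_pool_by_content("deployment failed")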
+ """ + r = self.tb.search(content) # Search Tbase + return self.tbasedoc2Memory(r) # Convert results to Memory format + + def get_memory_pool_by_key_content(self, key: str, content: str) -> Memory: + """Get memory pool based on key and content search. + + Args: + key (str): Key to search for in messages. + content (str): Content to search for in messages. + + Returns: + Memory: Retrieved messages in memory format. + """ + if key == "keyword": + query = f"@{key}:{{{content}}}" # Special handling for keywords + else: + query = f"@{key}:{content}" # General query + r = self.tb.search(query) # Search Tbase + return self.tbasedoc2Memory(r) # Convert results to Memory format + + def get_memory_pool_by_all(self, search_key_contents: dict, limit: int =10) -> Memory: + """Get memory pool based on multiple search criteria. + + Args: + search_key_contents (dict): Dictionary containing key-value pairs for searching messages. + limit (int, optional): The maximum number of messages to retrieve (default is 10). + + Returns: + Memory: Retrieved messages in memory format. + """ + querys = [] + for k, v in search_key_contents.items(): + if not v: continue + if k == "keyword": + querys.append(f"@{k}:{{{v}}}") + elif k == "role_tags": + tags_str = '|'.join([f"*{tag}*" for tag in v]) if isinstance(v, list) else f"{v}" + querys.append(f"@role_tags:{tags_str}") + elif k == "start_datetime": + query = f"(@start_datetime:[{v[0]} {v[1]}])" + querys.append(query) + elif k == "end_datetime": + query = f"(@end_datetime:[{v[0]} {v[1]}])" + querys.append(query) + else: + querys.append(f"@{k}:{v}") + + query = f"({')('.join(querys)})" if len(querys) >=2 else "".join(querys) + r = self.tb.search(query, limit=limit) + return self.tbasedoc2Memory(r) + + def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, session_index: str = "default", **kwargs) -> Memory: + """Retrieve memory using vector embeddings based on input text. + + Args: + text (str): The input text for which embeddings are generated. + top_k (int, optional): Number of top results to retrieve (default is 1). + score_threshold (float, optional): Minimum score for fetching results (default is 1.0). + session_index (str, optional): Session identifier (default is "default"). + + Returns: + Memory: Retrieved messages in memory format. + """ + if text is None: return Memory(messages=[]) + if not self.use_vector and self.embed_config: + logger.error(f"can't use vector search, because the use_vector is {self.use_vector}") + return Memory(messages=[]) + + if self.use_vector and self.embed_config: + query_embedding = self._get_embedding_array(text) + + base_query = f'(@session_index:{session_index})=>[KNN {top_k} @vector $vector AS distance]' + query_params = {"vector": query_embedding} + r = self.tb.vector_search(base_query, query_params=query_params) + return self.tbasedoc2Memory(r) + + def text_retrieval(self, text: str, session_index: str = "default", **kwargs) -> Memory: + """Retrieve messages based on text content and session index. + + Args: + text (str): The text to search for. + session_index (str, optional): Session identifier (default is "default"). + + Returns: + Memory: Retrieved messages in memory format. 
+ """ + keywords = extract_tags(text, topK=-1) + if len(keywords) > 0: + keyword = "|".join(keywords) + query = f"(@session_index:{session_index})(@keyword:{{{keyword}}})" + else: + query = f"@session_index:{session_index}" + # logger.debug(f"text_retrieval query: {query}") + r = self.tb.search(query) + memory = self.tbasedoc2Memory(r) + return self._text_retrieval_from_cache(memory.messages, text) + + def datetime_retrieval( + self, + session_index: str, + datetime: str, + text: str = None, + n: int = 5, + key: str = "start_datetime", + **kwargs + ) -> Memory: + """Retrieve messages based on datetime range and session index. + + Args: + session_index (str): The session index to filter messages. + datetime (str): The timestamp used for filtering messages. + text (str, optional): Optional text to retrieve alongside datetime. + n (int, optional): Number of minutes to define the range (default is 5). + key (str, optional): The key for datetime filtering (default is "start_datetime"). + + Returns: + Memory: Retrieved messages in memory format. + """ + + intput_timestamp = None + for datetime_format in ["%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"]: + try: + intput_timestamp = dateformatToTimestamp(datetime, 1000, datetime_format) + break + except: + pass + if intput_timestamp is None: + raise ValueError(f"can't transform datetime into [%Y-%m-%d %H:%M:%S.%f, %Y-%m-%d %H:%M:%S]") + + query = f"(@session_index:{session_index})(@{key}:[{intput_timestamp-n*60*1000} {intput_timestamp+n*60*1000}])" + # logger.debug(f"datetime_retrieval query: {query}") + r = self.tb.search(query) + memory = self.tbasedoc2Memory(r) + return self._text_retrieval_from_cache(memory.messages, text) + + def _text_retrieval_from_cache( + self, + messages: List[Message], + text: str = None, + score_threshold=0.3, + topK=5, + tag_topK=5, + **kwargs + ) -> Memory: + """Retrieve messages based on text similarity from cached messages.""" + + if text is None: + return Memory(messages=messages[:topK]) + + if len(messages) < topK: + return Memory(messages=messages) + + keywords = extract_tags(text, topK=tag_topK) + + matched_messages = [] + for message in messages: + message_keywords = extract_tags( + message.step_content or message.content or message.input_text, + topK=tag_topK + ) + # calculate jaccard similarity + intersection = Counter(keywords) & Counter(message_keywords) + union = Counter(keywords) | Counter(message_keywords) + similarity = sum(intersection.values()) / sum(union.values()) + if similarity >= score_threshold: + matched_messages.append((message, similarity)) + matched_messages = sorted(matched_messages, key=lambda x:x[1]) + return Memory(messages=[m for m, s in matched_messages][:topK]) + + def recursive_summary( + self, + messages: List[Message], + session_index: str, + split_n: int = 20 + ) -> Memory: + """Generate a recursive summary of the provided messages. + + Args: + messages (List[Message]): List of messages to summarize. + session_index (str): Session identifier for the summary. + split_n (int, optional): Number of messages to include in each summary pass (default is 20). + + Returns: + Memory: Updated messages including the summary. 
+ """ + + if len(messages) == 0: + return Memory(messages=messages) + + newest_messages = messages[-split_n:] + summary_messages = messages[:len(messages)-split_n] + + while (len(newest_messages) != 0) and (newest_messages[0].role_type != "user"): + message = newest_messages.pop(0) + summary_messages.append(message) + + # summary + model = self._get_model() + summary_content = '\n\n'.join([ + m.role_type + "\n" + "\n".join(([f"*{k}* {v}" for parsed_output in m.spec_parsed_contents for k, v in parsed_output.items() if k not in ['Action Status']])) + for m in summary_messages if m.role_type not in ["summary"] + ]) + # summary_prompt = CONV_SUMMARY_PROMPT_SPEC.format(conversation=summary_content) + summary_prompt = createSummaryPrompt(conversation=summary_content) + logger.debug(f"{summary_prompt}") + content = model.predict(summary_prompt) + summary_message = Message( + session_index=session_index, + role_name="summaryer", + role_type="summary", + content=content, + step_content=content, + parsed_output_list=[], + global_kwargs={} + ) + summary_message.spec_parsed_contents.append({"summary": content}) + newest_messages.insert(0, summary_message) + return Memory(messages=newest_messages) + + def localMessage2TbaseMessage(self, message: Message, role_tag: str= None): + """Convert a local Message object to a format suitable for Tbase storage.""" + + r = self.tb.search(f"@message_index: {message.message_index}") + history_role_tags = json.loads(r.docs[0].role_tags) if r.total == 1 else [] + + tbase_message = {} + for k, v in message.dict().items(): + v = list(set(history_role_tags+[role_tag])) if k=="role_tags" and role_tag else v + if isinstance(v, dict) or isinstance(v, list): + v = json.dumps(v, ensure_ascii=False) + tbase_message[k] = v + + tbase_message["start_datetime"] = dateformatToTimestamp(message.start_datetime, 1000, "%Y-%m-%d %H:%M:%S.%f") + tbase_message["end_datetime"] = dateformatToTimestamp(message.end_datetime, 1000, "%Y-%m-%d %H:%M:%S.%f") + + if self.use_vector and self.embed_config: + tbase_message["vector"] = self._get_embedding_array(message.content) + tbase_message["keyword"] = " | ".join(extract_tags(message.content, topK=-1) + + [tbase_message["message_index"].split("-")[0]]) + + tbase_message = { + k: v for k, v in tbase_message.items() + if k in self.save_message_keys + } + return tbase_message + + def tbasedoc2Memory(self, r_docs) -> Memory: + """Convert Tbase documents back into Memory objects.""" + + memory = Memory() + for doc in r_docs.docs: + tbase_message = {} + for k, v in doc.__dict__.items(): + if k in ["content", "input_text"]: + tbase_message[k] = v + continue + try: + v = json.loads(v) + except: + pass + + tbase_message[k] = v + + message = Message(**tbase_message) + memory.append(message) + + for message in memory.messages: + message.start_datetime = timestampToDateformat(int(message.start_datetime), 1000, "%Y-%m-%d %H:%M:%S.%f") + message.end_datetime = timestampToDateformat(int(message.end_datetime), 1000, "%Y-%m-%d %H:%M:%S.%f") + + memory.sort_by_key("end_datetime") + return memory + + + def init_global_msg(self, session_index: str, role_name: str, content: str, role_type: str = "global_value") -> bool: + """Initialize a global message and append it to the memory. + + Args: + session_index (str): The session index to which the message belongs. + role_name (str): The role name for the message. + content (str): The content of the message. + role_type (str, optional): The role type of the message (default is "global_value"). 
+
+        Returns:
+            bool: True if the message was initialized successfully; otherwise, False.
+        """
+
+        msg = Message(
+            session_index=session_index,
+            message_index=role_name,
+            role_name=role_name,
+            role_type=role_type,
+            content=content
+        )
+        try:
+            self.append(msg)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to initialize global message: {e}")
+            return False
+
+    def get_msg_by_role_name(self, session_index: str, role_name: str) -> Optional[Message]:
+        """Retrieve a message by its role name within a session.
+
+        Args:
+            session_index (str): The session index to search within.
+            role_name (str): The role name of the desired message.
+
+        Returns:
+            Optional[Message]: The found message, or None if not found.
+        """
+
+        memory = self.get_memory_pool_by_all({"session_index": session_index, "role_name": role_name})
+        for msg in memory.messages:
+            if msg.role_name == role_name:
+                return msg
+        return None
+
+    def get_msg_content_by_role_name(self, session_index: str, role_name: str) -> Optional[str]:
+        """Retrieve the content of a message by its role name.
+
+        Args:
+            session_index (str): The session index to search within.
+            role_name (str): The role name of the desired message.
+
+        Returns:
+            Optional[str]: The content of the found message, or None if not found.
+        """
+
+        message = self.get_msg_by_role_name(session_index, role_name)
+        if message is None:
+            return None
+        return message.content
+
+    def update_msg_content_by_rule(self, session_index: str, role_name: str, new_content: str, update_rule: str) -> bool:
+        """Update the content of a message based on an update rule.
+
+        Args:
+            session_index (str): The session index to search within.
+            role_name (str): The role name of the message to update.
+            new_content (str): The new content to apply.
+            update_rule (str): The rule to apply for the update.
+
+        Returns:
+            bool: True if the message was successfully updated; otherwise, False.
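+
+        Example (illustrative values)::
+
+            ok = mm.update_msg_content_by_rule(
+                "demo", "global_status", new_content="deploy finished",
+                update_rule="merge the update into the existing status",
+            )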
+ """ + + message = self.get_msg_by_role_name(session_index, role_name) + + if message == None: + return False + + prompt = f"{new_content}\n{role_name}:{message.content}\n{update_rule}" + model = self._get_model() + + new_content = model.predict(prompt) + + if new_content is not None: + message.content = new_content + self.append(message) + return True + else: + return False + + def _get_embedding(self, text) -> Dict[str, List[float]]: + text_vector = {} + if self.embed_config and text: + if isinstance(self.embed_config, ModelConfig): + self.emebd_model = get_model(self.embed_config) + vector = self.emebd_model.embed_query(text) + text_vector = {text: vector} + else: + text_vector = get_embedding( + self.embed_config.embed_engine, [text], + self.embed_config.embed_model_path, self.embed_config.model_device, + self.embed_config + ) + else: + text_vector = {text: [random.random() for _ in range(768)]} + return text_vector + + def _get_embedding_array(self, text) -> Dict[str, List[bytes]]: + text_vector = self._get_embedding(text) + return np.array(text_vector[text]).\ + astype(dtype=np.float32).tobytes() + + def _get_model(self, ): + if isinstance(self.llm_config, LLMConfig): + model = getChatModelFromConfig(self.llm_config) + else: + model = get_model(self.llm_config) + return model \ No newline at end of file diff --git a/muagent/models/__init__.py b/muagent/models/__init__.py new file mode 100644 index 0000000..b97b2d5 --- /dev/null +++ b/muagent/models/__init__.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +""" Import modules in models package.""" +from typing import Type +from loguru import logger + +from ..schemas.models import ModelConfig +from .base_model import ModelWrapperBase +from .openai_model import ( + OpenAIWrapperBase, + OpenAIChatWrapper, + # OpenAIDALLEWrapper, + OpenAIEmbeddingWrapper, +) +from .dashscope_model import ( + DashScopeChatWrapper, + # DashScopeImageSynthesisWrapper, + DashScopeTextEmbeddingWrapper, + # DashScopeMultiModalWrapper, +) +from .ollama_model import ( + OllamaChatWrapper, + OllamaEmbeddingWrapper, + # OllamaGenerationWrapper, +) +from .qwen_model import ( + QwenChatWrapper, + QwenTextEmbeddingWrapper +) +from .kimi_model import ( + KimiChatWrapper, + KimiEmbeddingWrapper +) +# from .gemini_model import ( +# GeminiChatWrapper, +# GeminiEmbeddingWrapper, +# ) +# from .zhipu_model import ( +# ZhipuAIChatWrapper, +# ZhipuAIEmbeddingWrapper, +# ) +# from .litellm_model import ( +# LiteLLMChatWrapper, +# ) +from .yi_model import ( + YiChatWrapper, +) + +__all__ = [ + "ModelWrapperBase", + "ModelResponse", + "PostAPIModelWrapperBase", + "PostAPIChatWrapper", + "OpenAIWrapperBase", + "OpenAIChatWrapper", + "OpenAIDALLEWrapper", + "OpenAIEmbeddingWrapper", + "DashScopeChatWrapper", + "DashScopeImageSynthesisWrapper", + "DashScopeTextEmbeddingWrapper", + "DashScopeMultiModalWrapper", + "OllamaChatWrapper", + "OllamaEmbeddingWrapper", + "OllamaGenerationWrapper", + "GeminiChatWrapper", + "GeminiEmbeddingWrapper", + "ZhipuAIChatWrapper", + "ZhipuAIEmbeddingWrapper", + "LiteLLMChatWrapper", + "YiChatWrapper", + "QwenChatWrapper", + "QwenTextEmbeddingWrapper", + "KimiChatWrapper", + "KimiEmbeddingWrapper" +] + + +def _get_model_wrapper(model_type: str) -> Type[ModelWrapperBase]: + """Get the specific type of model wrapper + + Args: + model_type (`str`): The model type name. + + Returns: + `Type[ModelWrapperBase]`: The corresponding model wrapper class. 
+ """ + wrapper = ModelWrapperBase.get_wrapper(model_type=model_type) + if wrapper is None: + raise KeyError( + f"Unsupported model_type [{model_type}]," + "use PostApiModelWrapper instead.", + ) + return wrapper + + +def get_model(model_config: ModelConfig) -> ModelWrapperBase: + """Get the model by model config + + Args: + model_config (`ModelConfig`): The model config + + Returns: + `ModelWrapperBase`: The specific model + """ + return ModelWrapperBase.from_config(model_config) \ No newline at end of file diff --git a/muagent/models/base_model.py b/muagent/models/base_model.py new file mode 100644 index 0000000..7231839 --- /dev/null +++ b/muagent/models/base_model.py @@ -0,0 +1,504 @@ +""" +The implementation of this _ModelWrapperMeta are borrowed from +https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/model.py +""" + + +from __future__ import annotations +from abc import ABCMeta, abstractmethod +from typing import ( + Any, + Optional, + Type, + Union, + Sequence, + List, + Generator, + Literal, + Mapping +) +from loguru import logger +from openai.types.chat import ChatCompletion, ChatCompletionChunk + +from muagent.schemas import Message, Memory +from muagent.schemas.models import ( + ModelConfig, +) +from muagent.utils.common_utils import _convert_to_str + + +class _ModelWrapperMeta(ABCMeta): + """A meta call to replace the model wrapper's __call__ function with + wrapper about error handling.""" + + def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any: + if "__call__" in attrs: + attrs["__call__"] = attrs["__call__"] + return super().__new__(mcs, name, bases, attrs) + + def __init__(cls, name: Any, bases: Any, attrs: Any) -> None: + if not hasattr(cls, "_registry"): + cls._registry = {} + cls._type_registry = {} + cls._deprecated_type_registry = {} + else: + cls._registry[name] = cls + if hasattr(cls, "model_type"): + cls._type_registry[cls.model_type] = cls + if hasattr(cls, "deprecated_model_type"): + cls._deprecated_type_registry[ + cls.deprecated_model_type + ] = cls + super().__init__(name, bases, attrs) + + +class ModelWrapperBase(metaclass=_ModelWrapperMeta): + """The base class for model wrapper.""" + + model_type: str + """The type of the model wrapper, which is to identify the model wrapper + class in model configuration.""" + + config_name: str + """The name of the model configuration.""" + + model_name: str + """The name of the model, which is used in model api calling.""" + + api_key: Optional[str] = None + """The api key of the model, which is used in model api calling.""" + + api_url: Optional[str] = None + """The api url of the model, which is used in model api calling.""" + + def __init__( + self, # pylint: disable=W0613 + config_name: str, + model_name: str, + model_type: str = "codefuse", + api_key: Optional[str] = "model_base_xxx", + api_url: Optional[str]="https://codefuse.ai", + **kwargs: Any, + ) -> None: + """Base class for model wrapper. + + All model wrappers should inherit this class and implement the + `__call__` function. + + Args: + config_name (`str`): + The id of the model, which is used to extract configuration + from the config file. + model_name (`str`): + The name of the model. + api_key (`str`): + The api key of the model. + api_url (`str`): + The api url of the model. + model_type (`str`): + The type of the model wrapper. 
+ """ + self.config_name = config_name + self.model_name = model_name + self.api_key = api_key + self.api_url = api_url + self.model_type = model_type + # logger.info(f"Initialize model by configuration [{config_name}]") + + @classmethod + def from_config(self, model_config: ModelConfig) -> 'ModelWrapperBase': + model_config_dict = model_config.dict() + model_type = model_config_dict.pop("model_type") + return self.get_wrapper(model_type)(**model_config_dict) + + @classmethod + def get_wrapper(cls, model_type: str) -> Type[ModelWrapperBase]: + """Get the specific model wrapper""" + if model_type in cls._type_registry: + return cls._type_registry[model_type] # type: ignore[return-value] + elif model_type in cls._registry: + return cls._registry[model_type] # type: ignore[return-value] + elif model_type in cls._deprecated_type_registry: + deprecated_cls = cls._deprecated_type_registry[model_type] + logger.warning( + f"Model type [{model_type}] will be deprecated in future " + f"releases, please use [{deprecated_cls.model_type}] instead.", + ) + return deprecated_cls # type: ignore[return-value] + else: + raise KeyError( + f"Unsupported model_type [{model_type}]," + "use PostApiModelWrapper instead.", + ) + + def __call__( + self, + prompt: str = None, + messages: Sequence[dict] = [], + tools: Sequence[object] = [], + *, + tool_choice: Optional[Literal['auto', 'required']] = None, + parallel_tool_calls: Optional[bool] = None, + stream: bool = None, + stop: Optional[str] = '', + format_type: Literal["str", "dict", "raw"] = "str", + **kwargs: Any, + ) -> Generator[Union[ChatCompletion, ChatCompletionChunk, str, Mapping], None, None]: + """Process input with the model. + + Args: + prompt (str, optional): The prompt string to provide to the model. + messages (Sequence[dict], optional): A sequence of messages for conversation context. + tools (Sequence[object], optional): Tools that can be utilized in the processing. + tool_choice (Optional[Literal['auto', 'required']], optional): Determining how to select tools. + parallel_tool_calls (Optional[bool], optional): If true, allows parallel calls to tools. + stream (bool, optional): If true, the output is streamed rather than returned all at once. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["str", "dict", "raw"], optional): The format of the output. + **kwargs: Additional keyword arguments for extensibility. + + Returns: + Generator[Union[ChatCompletion, ChatCompletionChunk, str, Mapping], None, None]: + A generator yielding completion responses from the model. + """ + raise NotImplementedError( + f"Model Wrapper [{type(self).__name__}]" + f" is missing the required `__call__`" + f" method.", + ) + + def predict( + self, + prompt: str, + stop: Optional[str] = '', + ) -> Union[ChatCompletion, str]: + """Generate a prediction based on the provided prompt. + + Args: + prompt (str): The input prompt for prediction. + stop (Optional[str], optional): Token to signify stopping generation. + + Returns: + Union[ChatCompletion, str]: The model's prediction in the specified format. + """ + return self.generate(prompt, stop, "str") + + def generate( + self, + prompt: str, + stop: Optional[str] = '', + format_type: Literal["str", "raw"] = "raw", + ) -> Union[ChatCompletion, str]: + """Generate a response by calling the model. + + Args: + prompt (str): The input prompt. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["str", "raw"], optional): The format of the output. 
+ + Returns: + Union[ChatCompletion, str]: The generated response from the model. + """ + for i in self.__call__(prompt, stop=stop, stream=False, format_type=format_type): + pass + return i + + def generate_stream(self, + prompt: str, + stop: Optional[str] = '', + format_type: Literal["str", "raw"] = "raw", + ) -> Generator[Union[ChatCompletionChunk, str], None, None]: + """Stream the generated response from the model. + + Args: + prompt (str): The input prompt. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["str", "raw"], optional): The format of the output. + + Yields: + Generator[Union[ChatCompletionChunk, str], None, None]: A generator yielding parts of the response. + """ + for i in self.__call__(prompt, stop=stop, stream=True, format_type=format_type): + yield i + + def chat(self, + messages: Optional[Sequence[dict]], + stop: Optional[str] = '', + format_type: Literal["str", "raw"] = "raw", + ) -> Union[ChatCompletion, str]: + """Process a chat message input and return the model's response. + + Args: + messages (Optional[Sequence[dict]]): A sequence of messages for conversation context. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["str", "raw"], optional): The format of the output. + + Returns: + Union[ChatCompletion, str]: The model's chat response in the specified format. + """ + for i in self.__call__(None, messages, stop=stop, stream=False, format_type=format_type): + return i + + def chat_stream(self, + messages: Optional[Sequence[dict]], + stop: Optional[str] = '', + format_type: Literal["str", "raw"] = "raw", + ) -> Generator[Union[ChatCompletionChunk, str], None, None]: + """Stream chat responses from the model. + + Args: + messages (Optional[Sequence[dict]]): A sequence of messages for conversation context. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["str", "raw"], optional): The format of the output. + + Yields: + Generator[Union[ChatCompletionChunk, str], None, None]: A generator yielding parts of the chat response. + """ + for i in self.__call__(None, messages, stop=stop, stream=True, format_type=format_type): + yield i + + def function_call( + self, + messages: Optional[Sequence[dict]] = None, + tools: Sequence[object] = [], + *, + prompt: Optional[str] = None, + tool_choice: Optional[Literal['auto', 'required']] = None, + parallel_tool_calls: Optional[bool] = None, + stream: Optional[bool] = False, + stop: Optional[str] = '', + format_type: Literal["raw"] = "raw", + ) -> Union[ChatCompletion, Mapping]: + """Call a function to process messages with optional tools. + + Args: + messages (Optional[Sequence[dict]], optional): A sequence of messages for context. + tools (Sequence[object], optional): Tools available for use. + prompt (Optional[str], optional): An optional prompt. + tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools. + parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls. + stream (Optional[bool], optional): If true, streams the output instead of returning it all at once. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["raw"], optional): Specifies to return the output in raw format. + + Returns: + Union[ChatCompletion, Mapping]: The result of the function call processed by the model. 
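+
+        Example (a sketch; the tool schema follows the OpenAI function-call
+        format, and the weather tool is hypothetical):
+
+        .. code-block:: python
+
+            tools = [{
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"city": {"type": "string"}},
+                    },
+                },
+            }]
+            resp = model.function_call(
+                messages=[{"role": "user", "content": "Weather in Hangzhou?"}],
+                tools=tools,
+                tool_choice="auto",
+            )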
+ """ + kwargs = locals() + kwargs.pop("self") + for i in self.__call__(**kwargs): + pass + return i + + def function_call_stream( + self, + messages: Optional[Sequence[dict]] = None, + tools: Sequence[object] = [], + *, + prompt: Optional[str] = None, + tool_choice: Optional[Literal['auto', 'required']] = 'auto', + parallel_tool_calls: Optional[bool] = None, + stream: Optional[bool] = True, + stop: Optional[str] = '', + format_type: Literal["raw"] = "raw", + ) -> Generator[Union[ChatCompletionChunk, Mapping], None, None]: + """Stream function call outputs. + + Args: + messages (Optional[Sequence[dict]], optional): A sequence of messages for context. + tools (Sequence[object], optional): Tools available for use. + prompt (Optional[str], optional): An optional prompt. + tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools. + parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls. + stream (Optional[bool], optional): If true, streams the output. + stop (Optional[str], optional): Token to signify stopping generation. + format_type (Literal["raw"], optional): Specifies to return output in raw format. + + Yields: + Generator[Union[ChatCompletionChunk, Mapping], None, None]: A generator yielding parts of the function output. + """ + kwargs = locals() + kwargs.pop("self") + for i in self.__call__(**kwargs): + yield i + + def batch(self, *args: Any, **kwargs: Any) -> List[ChatCompletion]: + """Process batch inputs with the model. + + This method should be implemented by subclasses. + + Raises: + NotImplementedError: If not implemented in subclass. + """ + raise NotImplementedError( + f"Model Wrapper [{type(self).__name__}]" + f" is missing the required `batch`" + f" method.", + ) + + def embed_query(self, text: str) -> List[float]: + """Embed a query into a vector representation. + + This method should be implemented by subclasses. + + Args: + text (str): The text to embed. + + Raises: + NotImplementedError: If not implemented in subclass. + """ + raise NotImplementedError( + f"Model Wrapper [{type(self).__name__}]" + f" is missing the required `embed_query`" + f" method.", + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed a list of documents into vector representations. + + This method should be implemented by subclasses. + + Args: + texts (List[str]): The list of texts to embed. + + Raises: + NotImplementedError: If not implemented in subclass. + """ + raise NotImplementedError( + f"Model Wrapper [{type(self).__name__}]" + f" is missing the required `embed_documents`" + f" method.", + ) + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> Union[List[dict], str]: + """Format the input messages into the format that the model + API required.""" + raise NotImplementedError( + f"Model Wrapper [{type(self).__name__}]" + f" is missing the required `format` method", + ) + + @staticmethod + def format_for_common_chat_models( + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """A common format strategy for chat models, which will format the + input messages into a system message (if provided) and a user message. + + Note this strategy maybe not suitable for all scenarios, + and developers are encouraged to implement their own prompt + engineering strategies. + + The following is an example: + + .. 
code-block:: python + + prompt1 = model.format( + Message("system", "You're a helpful assistant", role="system"), + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + prompt2 = model.format( + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + The prompt will be as follows: + + .. code-block:: python + + # prompt1 + [ + { + "role": "system", + "content": "You're a helpful assistant" + }, + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + # prompt2 + [ + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + + Args: + args (`Union[Message, Sequence[Message]]`): + The input arguments to be formatted, where each argument + should be a `Message` object, or a list of `Message` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages. + """ + if len(args) == 0: + raise ValueError( + "At least one message should be provided. An empty message " + "list is not allowed.", + ) + + # Parse all information into a list of messages + input_Messages = [] + for _ in args: + if _ is None: + continue + if isinstance(_, Message): + input_Messages.append(_) + elif isinstance(_, list) and all(isinstance(__, Message) for __ in _): + input_Messages.extend(_) + else: + raise TypeError( + f"The input should be a Message object or a list " + f"of Message objects, got {type(_)}.", + ) + + # record dialog history as a list of strings + dialogue = [] + sys_prompt = None + for i, unit in enumerate(input_Messages): + if i == 0 and unit.role == "system": + # if system prompt is available, place it at the beginning + sys_prompt = _convert_to_str(unit.content) + else: + # Merge all messages into a conversation history prompt + dialogue.append( + f"{unit.name}: {_convert_to_str(unit.content)}", + ) + + content_components = [] + + # The conversation history is added to the user message if not empty + if len(dialogue) > 0: + content_components.extend(["## Conversation History"] + dialogue) + + messages = [ + { + "role": "user", + "content": "\n".join(content_components), + }, + ] + + # Add system prompt at the beginning if provided + if sys_prompt is not None: + messages = [{"role": "system", "content": sys_prompt}] + messages + + return messages \ No newline at end of file diff --git a/muagent/models/dashscope_model.py b/muagent/models/dashscope_model.py new file mode 100644 index 0000000..4aca149 --- /dev/null +++ b/muagent/models/dashscope_model.py @@ -0,0 +1,514 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for DashScope models""" +import os +from abc import ABC +from http import HTTPStatus +from typing import ( + Any, + Union, + List, + Sequence, + Optional, + Generator, + Literal +) + +from loguru import logger + +from ..schemas import Message + +try: + import dashscope + + dashscope_version = dashscope.version.__version__ + if dashscope_version < "1.19.0": + logger.warning( + f"You are using 'dashscope' version {dashscope_version}, " + "which is below the recommended version 1.19.0. 
" + "Please consider upgrading to maintain compatibility.", + ) + from dashscope.api_entities.dashscope_response import GenerationResponse +except ImportError: + dashscope = None + GenerationResponse = None + +from .base_model import ModelWrapperBase +from ..utils.common_utils import _convert_to_str + + + +class DashScopeWrapperBase(ModelWrapperBase, ABC): + """The model wrapper for DashScope API.""" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the DashScope wrapper. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in DashScope API. + api_key (`str`, default `None`): + The API key for DashScope API. + generate_args (`dict`, default `None`): + The extra keyword arguments used in DashScope api generation, + e.g. `temperature`, `seed`. + """ + if model_name is None: + model_name = config_name + logger.warning("model_name is not set, use config_name instead.") + + super().__init__(config_name=config_name, model_name=model_name) + + if dashscope is None: + raise ImportError( + "The package 'dashscope' is not installed. Please install it " + "by running `pip install dashscope>=1.19.0`", + ) + + self.generate_args = generate_args or {} + + self.api_key = api_key + self.max_length = None + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. Please try to use the " + f"model wrapper directly.", + ) + + +class DashScopeChatWrapper(DashScopeWrapperBase): + """The model wrapper for DashScope's chat API, refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details + + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/quick-start?spm=a2c4g.11186623.0.0.7e346eb5RvirBw + + ```json + { + "status_code": 200, + "request_id": "a75a1b22-e512-957d-891b-37db858ae738", + "code": "", + "message": "", + "output": { + "text": null, + "finish_reason": null, + "choices": [ + { + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "xxx" + } + } + ] + }, + "usage": { + "input_tokens": 25, + "output_tokens": 77, + "total_tokens": 102 + } + } + ``` + """ + + model_type: str = "dashscope_chat" + + deprecated_model_type: str = "tongyi_chat" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + stream: bool = False, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the DashScope wrapper. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in DashScope API. + api_key (`str`, default `None`): + The API key for DashScope API. + stream (`bool`, default `False`): + If True, the response will be a generator in the `stream` + field of the returned `ModelResponse` object. + generate_args (`dict`, default `None`): + The extra keyword arguments used in DashScope api generation, + e.g. `temperature`, `seed`. 
+ """ + + super().__init__( + config_name=config_name, + model_name=model_name, + api_key=api_key, + generate_args=generate_args, + **kwargs, + ) + + self.stream = stream + + def __call__( + self, + prompt: str = None, + messages: Sequence[dict] = [], + tools: Sequence[object] = [], + *, + tool_choice: Optional[Literal['auto', 'required']] = None, + parallel_tool_calls: Optional[bool] = None, + stop: Optional[str] = '', + stream: Optional[bool] = None, + format_type: Literal['str', 'raw', 'dict'] = 'raw', + **kwargs: Any, + ) -> Generator: + """Processes a list of messages to construct a payload for the + DashScope API call. It then makes a request to the DashScope API + and returns the response. This method also updates monitoring + metrics based on the API response. + + Each message in the 'messages' list can contain text content and + optionally an 'image_urls' key. If 'image_urls' is provided, + it is expected to be a list of strings representing URLs to images. + These URLs will be transformed to a suitable format for the DashScope + API, which might involve converting local file paths to data URIs. + + Args: + messages (`list`): + A list of messages to process. + stream (`Optional[bool]`, default `None`): + The stream flag to control the response format, which will + overwrite the stream flag in the constructor. + **kwargs (`Any`): + The keyword arguments to DashScope chat completions API, + e.g. `temperature`, `max_tokens`, `top_p`, etc. Please + refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details + for more detailed arguments. + + Returns: + `ModelResponse`: + A response object with the response text in text field, and + the raw response in raw field. If stream is True, the response + will be a generator in the `stream` field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved for + `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and check + the response generated by the model, which takes the response + as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + The rule of roles in messages for DashScope is very rigid, + for more details, please refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details + """ + + messages = [{"role": "user", "content": prompt}] if prompt else messages + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: checking messages + if not isinstance(messages, list): + raise ValueError( + "Dashscope `messages` field expected type `list`, " + f"got `{type(messages)}` instead.", + ) + if not all("role" in msg and "content" in msg for msg in messages): + raise ValueError( + "Each message in the 'messages' list must contain a 'role' " + "and 'content' key for DashScope API.", + ) + + # step3: forward to generate response + if stream is None: + stream = self.stream + + kwargs.update( + { + "model": self.model_name, + "messages": messages, + # Set the result to be "message" format. 
+ "result_format": "message", + "stream": stream, + "tools": tools, + "stop": stop, + }, + ) + + # Switch to the incremental_output mode + if stream: + kwargs["incremental_output"] = True + + response = dashscope.Generation.call(api_key=self.api_key, **kwargs) + + # step3: invoke llm api, record the invocation and update the monitor + if format_type == "str": + content = "" + if stream: + for chunk in response: + content += chunk["output"]["choices"][0]["message"]["content"] or '' + yield content + else: + yield response["output"]["choices"][0]["message"]["content"] + else: + if stream: + for chunk in response: + yield chunk + else: + yield response + + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """A common format strategy for chat models, which will format the + input messages into a user message. + + Note this strategy maybe not suitable for all scenarios, + and developers are encouraged to implement their own prompt + engineering strategies. + + The following is an example: + + .. code-block:: python + + prompt1 = model.format( + Message("system", "You're a helpful assistant", role="system"), + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + prompt2 = model.format( + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + The prompt will be as follows: + + .. code-block:: python + + # prompt1 + [ + { + "role": "system", + "content": "You're a helpful assistant" + }, + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + # prompt2 + [ + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages. + """ + + return ModelWrapperBase.format_for_common_chat_models(*args) + + + def format_prompt(self, *args: Union[Message, Sequence[Message]]) -> str: + """Forward the input to the model. + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `str`: + The formatted string prompt. 
+ """ + input_msgs: List[Message] = [] + for _ in args: + if _ is None: + continue + if isinstance(_, Message): + input_msgs.append(_) + elif isinstance(_, list) and all(isinstance(__, Message) for __ in _): + input_msgs.extend(_) + else: + raise TypeError( + f"The input should be a Msg object or a list " + f"of Msg objects, got {type(_)}.", + ) + + sys_prompt = None + dialogue = [] + for i, unit in enumerate(input_msgs): + if i == 0 and unit.role_type == "system": + # system prompt + sys_prompt = unit.content + else: + # Merge all messages into a conversation history prompt + dialogue.append( + f"{unit.role_name}: {unit.content}", + ) + + dialogue_history = "\n".join(dialogue) + + if sys_prompt is None: + prompt_template = "## Conversation History\n{dialogue_history}" + else: + prompt_template = ( + "{system_prompt}\n" + "\n" + "## Conversation History\n" + "{dialogue_history}" + ) + + return prompt_template.format( + system_prompt=sys_prompt, + dialogue_history=dialogue_history, + ) + +class DashScopeTextEmbeddingWrapper(DashScopeWrapperBase): + """The model wrapper for DashScope Text Embedding API. + + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details?spm=a2c4g.11186623.0.i3 + + ```json + { + "status_code": 200, // 200 indicate success otherwise failed. + "request_id": "fd564688-43f7-9595-b986", // The request id. + "code": "", // If failed, the error code. + "message": "", // If failed, the error message. + "output": { + "embeddings": [ // embeddings + { + "embedding": [ // one embedding output + -3.8450357913970947, ..., + ], + "text_index": 0 // the input index. + } + ] + }, + "usage": { + "total_tokens": 3 // the request tokens. + } + } + ``` + """ + + model_type: str = "dashscope_text_embedding" + + def __call__( + self, + texts: Union[list[str], str], + dimension: Literal[512, 768, 1024, 1536] = 768, + **kwargs: Any, + ): + """Embed the messages with DashScope Text Embedding API. + + Args: + texts (`list[str]` or `str`): + The messages used to embed. + **kwargs (`Any`): + The keyword arguments to DashScope Text Embedding API, + e.g. `text_type`. Please refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details-15 + for more detailed arguments. + + Returns: + `ModelResponse`: + A list of embeddings in embedding field and the raw + response in raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved + for `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and + check the response generated by the model, which takes the + response as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: forward to generate response + response = dashscope.TextEmbedding.call( + input=texts, + model=self.model_name, + api_key=self.api_key, + dimension=dimension, + **kwargs, + ) + + if response.status_code != HTTPStatus.OK: + error_msg = ( + f" Request id: {response.request_id}," + f" Status code: {response.status_code}," + f" error code: {response.code}," + f" error message: {response.message}." 
+ ) + raise RuntimeError(error_msg) + + # step5: return response + return response + + def embed_query(self, text: str) -> List[float]: + response = self([text]) + output = response["output"] + embeddings = output["embeddings"] + return embeddings[0]["embedding"] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + response = self(texts) + output = response["output"] + embeddings = output["embeddings"] + return [emb["embedding"] for emb in embeddings] + \ No newline at end of file diff --git a/muagent/models/kimi_model.py b/muagent/models/kimi_model.py new file mode 100644 index 0000000..e2ee93c --- /dev/null +++ b/muagent/models/kimi_model.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for OpenAI models +The implementation of this _ModelWrapperMeta are borrowed from +https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py +""" + + +from abc import ABC +from typing import ( + Union, + Any, + List, + Sequence, + Dict, + Optional, + Generator, + Literal +) +from urllib.parse import urlparse +import os +import base64 +from loguru import logger +try: + import openai +except ImportError as e: + raise ImportError( + "Cannot find openai package, please install it by " + "`pip install openai`", + ) from e + +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types import CreateEmbeddingResponse +from .base_model import ModelWrapperBase +from ..schemas import Message + + + +class KimiWrapperBase(ModelWrapperBase, ABC): + """The model wrapper for OpenAI API. + + Response: + - From https://platform.moonshot.cn/docs/intro + + ```json + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?", + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + ``` + """ + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + api_url: str = "https://api.moonshot.cn/v1", + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the openai client. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in OpenAI API. + api_key (`str`, default `None`): + The API key for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_API_KEY`. + organization (`str`, default `None`): + The organization ID for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_ORGANIZATION`. + client_args (`dict`, default `None`): + The extra keyword arguments to initialize the OpenAI client. + generate_args (`dict`, default `None`): + The extra keyword arguments used in openai api generation, + e.g. `temperature`, `seed`. 
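+
+        Example (a sketch; the key is a placeholder from the Moonshot
+        console):
+
+        .. code-block:: python
+
+            model = KimiChatWrapper(
+                config_name="kimi",
+                model_name="moonshot-v1-8k",
+                api_key="sk-xxx",
+            )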
+ """ + + if model_name is None: + model_name = config_name + logger.warning("model_name is not set, use config_name instead.") + + init_params = locals() + init_params.pop("self") + init_params["model_type"] = self.model_type + super().__init__(**init_params) + # super().__init__(config_name=config_name, model_name=model_name) + + self.generate_args = generate_args or {} + self.api_url = api_url or "https://api.moonshot.cn/v1" + self.client = openai.OpenAI(api_key=api_key, base_url=self.api_url,) + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. Please try to use the " + f"model wrapper directly.", + ) + + +class KimiChatWrapper(KimiWrapperBase): + """The model wrapper for OpenAI's chat API.""" + + model_type: str = "moonshot_chat" + + substrings_in_vision_models_names = ["gpt-4-turbo", "vision", "gpt-4o"] + """The substrings in the model names of vision models.""" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + api_url: str = "https://api.moonshot.cn/v1", + stream: bool = False, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the openai client. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in OpenAI API. + api_key (`str`, default `None`): + The API key for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_API_KEY`. + organization (`str`, default `None`): + The organization ID for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_ORGANIZATION`. + client_args (`dict`, default `None`): + The extra keyword arguments to initialize the OpenAI client. + stream (`bool`, default `False`): + Whether to enable stream mode. + generate_args (`dict`, default `None`): + The extra keyword arguments used in openai api generation, + e.g. `temperature`, `seed`. + """ + + init_params = locals() + init_params.pop("self") + init_params["model_type"] = self.model_type + self.generate_args = generate_args + super().__init__(**init_params) + self.stream = stream + + def __call__( + self, + prompt: str = None, + messages: Sequence[dict] = [], + tools: Sequence[object] = [], + *, + tool_choice: Optional[Literal['auto', 'required']] = None, + parallel_tool_calls: Optional[bool] = None, + stop: Optional[str] = '', + stream: bool = None, + format_type: Literal['str', 'raw', 'dict'] = 'raw', + **kwargs: Any, + ) -> Generator[Union[ChatCompletionChunk, ChatCompletion], None, None]: + """Processes a list of messages to construct a payload for the OpenAI + API call. It then makes a request to the OpenAI API and returns the + response. This method also updates monitoring metrics based on the + API response. + + Each message in the 'messages' list can contain text content and + optionally an 'image_urls' key. If 'image_urls' is provided, + it is expected to be a list of strings representing URLs to images. + These URLs will be transformed to a suitable format for the OpenAI + API, which might involve converting local file paths to data URIs. + + Args: + messages (`list`): + A list of messages to process. + stream (`Optional[bool]`, defaults to `None`) + Whether to enable stream mode, which will override the + `stream` argument in the constructor if provided. + **kwargs (`Any`): + The keyword arguments to OpenAI chat completions API, + e.g. 
`temperature`, `max_tokens`, `top_p`, etc. Please refer to + https://platform.openai.com/docs/api-reference/chat/create + for more detailed arguments. + + Returns: + `ModelResponse`: + The response text in text field, and the raw response in + raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved for + `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and check + the response generated by the model, which takes the response + as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + + messages = [{"role": "user", "content": prompt}] if prompt else messages + + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: checking messages + if not isinstance(messages, list): + raise ValueError( + "Kimi `messages` field expected type `list`, " + f"got `{type(messages)}` instead.", + ) + if not all("role" in Message and "content" in Message for Message in messages): + raise ValueError( + "Each message in the 'messages' list must contain a 'role' " + "and 'content' key for OpenAI API.", + ) + + # step3: forward to generate response + if stream is None: + stream = self.stream + + kwargs.update( + { + "model": self.model_name, + "messages": messages, + "stream": stream, + "tools": tools, + "tool_choice": tool_choice, + "parallel_tool_calls": parallel_tool_calls, + "stop": stop + }, + ) + + response = self.client.chat.completions.create(**kwargs) + + if format_type == "str": + content = "" + if stream: + for chunk in response: + content += chunk.choices[0].delta.content or '' + yield content + else: + yield response.choices[0].message.content + else: + if stream: + for chunk in response: + yield chunk + else: + yield response + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """Format the input string and dictionary into the format that + OpenAI Chat API required. If you're using a OpenAI-compatible model + without a prefix "gpt-" in its name, the format method will + automatically format the input messages into the required format. + + Args: + args (`Union[Message, Sequence[Message]]`): + The input arguments to be formatted, where each argument + should be a `Message` object, or a list of `Message` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages in the format that OpenAI Chat API + required. + """ + + return ModelWrapperBase.format_for_common_chat_models(*args) + + +class KimiEmbeddingWrapper(KimiWrapperBase): + """The model wrapper for OpenAI embedding API. + + Response: + - Refer to + https://xx + + ```json + { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + ``` + """ + + model_type: str = "kimi_embedding" + + def __call__( + self, + texts: Union[list[str], str], + **kwargs: Any, + ) -> CreateEmbeddingResponse: + """Embed the messages with OpenAI embedding API. + + Args: + texts (`list[str]` or `str`): + The messages used to embed. 
+ **kwargs (`Any`): + The keyword arguments to OpenAI embedding API, + e.g. `encoding_format`, `user`. Please refer to + https://platform.openai.com/docs/api-reference/embeddings + for more detailed arguments. + + Returns: + `ModelResponse`: + A list of embeddings in embedding field and the + raw response in raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved for + `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and check + the response generated by the model, which takes the response + as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + raise NotImplementedError( + f"Model Wrapper [{type(self).__name__}]" + f" is missing the required `__call__` method", + ) + diff --git a/muagent/models/ollama_model.py b/muagent/models/ollama_model.py new file mode 100644 index 0000000..fd73f7e --- /dev/null +++ b/muagent/models/ollama_model.py @@ -0,0 +1,490 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for Ollama models.""" +import os +from abc import ABC +from typing import ( + Sequence, + Any, + Optional, + List, + Union, + Generator, + Literal, + Mapping +) + +from .base_model import ModelWrapperBase +from ..schemas import Message + + +class OllamaWrapperBase(ModelWrapperBase, ABC): + """The base class for Ollama model wrappers. + + To use Ollama API, please + 1. First install ollama server from https://ollama.com/download and + start the server + 2. Pull the model by `ollama pull {model_name}` in terminal + After that, you can use the ollama API. + """ + + model_type: str + """The type of the model wrapper, which is to identify the model wrapper + class in model configuration.""" + + model_name: str + """The model name used in ollama API.""" + + options: dict + """A dict contains the options for ollama generation API, + e.g. {"temperature": 0, "seed": 123}""" + + keep_alive: str + """Controls how long the model will stay loaded into memory following + the request.""" + + def __init__( + self, + config_name: str, + model_name: str, + api_key: str = '', + options: dict = None, + keep_alive: str = "5m", + api_url: Optional[Union[str, None]] = "http://127.0.0.1:11434", + **kwargs: Any, + ) -> None: + """Initialize the model wrapper for Ollama API. + + Args: + model_name (`str`): + The model name used in ollama API. + options (`dict`, default `None`): + The extra keyword arguments used in Ollama api generation, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `5m`): + Controls how long the model will stay loaded into memory + following the request. + host (`str`, default `None`): + The host port of the ollama server. + Defaults to `None`, which is 127.0.0.1:11434. + """ + + super().__init__(config_name=config_name, model_name=model_name) + + self.options = options + self.keep_alive = keep_alive + self.api_url = api_url or "http://127.0.0.1:11434" + + try: + import ollama + except ImportError as e: + raise ImportError( + "The package ollama is not found. Please install it by " + 'running command `pip install "ollama>=0.1.7"`', + ) from e + + self.client = ollama.Client(host=self.api_url) + + +class OllamaChatWrapper(OllamaWrapperBase): + """The model wrapper for Ollama chat API. 
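+
+    Example (a sketch; assumes a local `ollama serve` and a model pulled
+    via `ollama pull llama3`):
+
+    .. code-block:: python
+
+        model = OllamaChatWrapper(config_name="ollama", model_name="llama3")
+        for out in model(prompt="Hello", format_type="str"):
+            print(out)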
+ + Response: + - Refer to + https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion + + ```json + { + "model": "registry.ollama.ai/library/llama3:latest", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": { + "role": "assistant", + "content": "Hello! How are you today?" + }, + "done": true, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000 + } + ``` + """ + + model_type: str = 'ollama_chat' + + def __init__( + self, + config_name: str, + model_name: str, + stream: bool = False, + options: dict = None, + keep_alive: str = "5m", + api_url: Optional[Union[str, None]] = None, + **kwargs: Any, + ) -> None: + """Initialize the model wrapper for Ollama API. + + Args: + model_name (`str`): + The model name used in ollama API. + stream (`bool`, default `False`): + Whether to enable stream mode. + options (`dict`, default `None`): + The extra keyword arguments used in Ollama api generation, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `5m`): + Controls how long the model will stay loaded into memory + following the request. + api_url (`str`, default `None`): + The host port of the ollama server. + Defaults to `None`, which is 127.0.0.1:11434. + """ + + super().__init__( + config_name=config_name, + model_name=model_name, + options=options, + keep_alive=keep_alive, + api_url=api_url, + **kwargs, + ) + + self.stream = stream + + def __call__( + self, + prompt: str = None, + messages: Sequence[dict] = [], + tools: Sequence[object] = [], + *, + tool_choice: Optional[Literal['auto', 'required']] = None, + parallel_tool_calls: Optional[bool] = None, + stop: Optional[str] = '', + stream: Optional[bool] = None, + options: Optional[dict] = None, + keep_alive: Optional[str] = None, + format_type: Literal['str', 'raw', 'dict'] = 'raw', + **kwargs: Any, + ): + """Generate response from the given messages. + + Args: + messages (`Sequence[dict]`): + A list of messages, each message is a dict contains the `role` + and `content` of the message. + stream (`bool`, default `None`): + Whether to enable stream mode, which will override the `stream` + input in the constructor. + options (`dict`, default `None`): + The extra arguments used in ollama chat API, which takes + effect only on this call, and will be merged with the + `options` input in the constructor, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `None`): + How long the model will stay loaded into memory following + the request, which takes effect only on this call, and will + override the `keep_alive` input in the constructor. + + Returns: + `ModelResponse`: + The response text in `text` field, and the raw response in + `raw` field. 
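+
+        Example (illustrative; with `stream=True` and `format_type="str"`
+        the generator yields the accumulated text so far):
+
+        .. code-block:: python
+
+            msgs = [{"role": "user", "content": "Tell me a joke"}]
+            for partial in model(messages=msgs, stream=True, format_type="str"):
+                print(partial)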
+ """ + + messages = [{"role": "user", "content": prompt}] if prompt else messages + # step1: prepare parameters accordingly + if options is None: + options = self.options or {"stop": [stop] if stop else []} + else: + options = {**self.options, **options} + + keep_alive = keep_alive or self.keep_alive + + # step2: forward to generate response + stream = self.stream if stream is None else stream + + kwargs.update( + { + "model": self.model_name, + "messages": messages, + "tools": tools, + "stream": stream, + "options": options, + "keep_alive": keep_alive, + }, + ) + + response = self.client.chat(**kwargs) + if format_type == "str": + content = "" + if stream: + for chunk in response: + content += chunk["message"]["content"] or '' + yield content + else: + yield response["message"]["content"] + else: + if stream: + for chunk in response: + yield chunk + else: + yield response + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """Format the messages for ollama Chat API. + + All messages will be formatted into a single system message with + system prompt and conversation history. + + Note: + 1. This strategy maybe not suitable for all scenarios, + and developers are encouraged to implement their own prompt + engineering strategies. + 2. For ollama chat api, the content field shouldn't be empty string. + + Example: + + .. code-block:: python + + prompt = model.format( + Message("system", "You're a helpful assistant", role="system"), + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + The prompt will be as follows: + + .. code-block:: python + + [ + { + "role": "system", + "content": "You're a helpful assistant" + }, + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + + Args: + args (`Union[Message, Sequence[Message]]`): + The input arguments to be formatted, where each argument + should be a `Message` object, or a list of `Message` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages. 
+ """ + + # Parse all information into a list of messages + input_msgs: List[Message] = [] + for _ in args: + if _ is None: + continue + if isinstance(_, Message): + input_msgs.append(_) + elif isinstance(_, list) and all(isinstance(__, Message) for __ in _): + input_msgs.extend(_) + else: + raise TypeError( + f"The input should be a Message object or a list " + f"of Message objects, got {type(_)}.", + ) + + # record dialog history as a list of strings + system_prompt = None + history_content_template = [] + dialogue = [] + # TODO: here we default the url links to images + images = [] + for i, unit in enumerate(input_msgs): + if i == 0 and unit.role_type == "system": + # system prompt + system_prompt = unit.content + else: + # Merge all messages into a conversation history prompt + dialogue.append( + f"{unit.role_name}: {unit.content}", + ) + + if unit.image_urls is not None: + images.extend(unit.image_urls) + + if len(dialogue) != 0: + dialogue_history = "\n".join(dialogue) + + history_content_template.extend( + ["## Conversation History", dialogue_history], + ) + + history_content = "\n".join(history_content_template) + + # The conversation history message + history_message = { + "role": "user", + "content": history_content, + } + + if len(images) != 0: + history_message["images"] = images + + if system_prompt is None: + return [history_message] + + return [ + {"role": "system", "content": system_prompt}, + history_message, + ] + + def format_prompt(self, *args: Union[Message, Sequence[Message]]) -> str: + """Forward the input to the model. + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `str`: + The formatted string prompt. + """ + input_msgs: List[Message] = [] + for _ in args: + if _ is None: + continue + if isinstance(_, Message): + input_msgs.append(_) + elif isinstance(_, list) and all(isinstance(__, Message) for __ in _): + input_msgs.extend(_) + else: + raise TypeError( + f"The input should be a Msg object or a list " + f"of Msg objects, got {type(_)}.", + ) + + sys_prompt = None + dialogue = [] + for i, unit in enumerate(input_msgs): + if i == 0 and unit.role_type == "system": + # system prompt + sys_prompt = unit.content + else: + # Merge all messages into a conversation history prompt + dialogue.append( + f"{unit.role_name}: {unit.content}", + ) + + dialogue_history = "\n".join(dialogue) + + if sys_prompt is None: + prompt_template = "## Conversation History\n{dialogue_history}" + else: + prompt_template = ( + "{system_prompt}\n" + "\n" + "## Conversation History\n" + "{dialogue_history}" + ) + + return prompt_template.format( + system_prompt=sys_prompt, + dialogue_history=dialogue_history, + ) + + +class OllamaEmbeddingWrapper(OllamaWrapperBase): + """The model wrapper for Ollama embedding API. + + Response: + - Refer to + https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings + + ```json + { + "model": "all-minilm", + "embeddings": [[ + 0.010071029, -0.0017594862, 0.05007221, 0.04692972, + 0.008599704, 0.105441414, -0.025878139, 0.12958129, + ]] + } + ``` + """ + + model_type: str = "ollama_embedding" + + def __call__( + self, + texts: str, + options: Optional[dict] = None, + keep_alive: Optional[str] = None, + **kwargs: Any, + ) -> Mapping[str, Sequence[float]]: + """Generate embedding from the given prompt. + + Args: + prompt (`str`): + The prompt to generate response. 
+ options (`dict`, default `None`): + The extra arguments used in ollama embedding API, which takes + effect only on this call, and will be merged with the + `options` input in the constructor, + e.g. `{"temperature": 0., "seed": 123}`. + keep_alive (`str`, default `None`): + How long the model will stay loaded into memory following + the request, which takes effect only on this call, and will + override the `keep_alive` input in the constructor. + + Returns: + `ModelResponse`: + The response embedding in `embedding` field, and the raw + response in `raw` field. + """ + # step1: prepare parameters accordingly + if options is None: + options = self.options + else: + options = {**self.options, **options} + + keep_alive = keep_alive or self.keep_alive + + # step2: forward to generate response + response = self.client.embed( + model=self.model_name, + input=texts, + options=options, + keep_alive=keep_alive, + **kwargs, + ) + # step5: return response + return response + + def embed_query(self, text: str) -> List[float]: + response = self([text]) + embeddings = response["embeddings"] + return embeddings[0] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + response = self(texts) + embeddings = response["embeddings"] + return embeddings + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. Please try to use the " + f"model wrapper directly.", + ) \ No newline at end of file diff --git a/muagent/models/openai_model.py b/muagent/models/openai_model.py new file mode 100644 index 0000000..338e530 --- /dev/null +++ b/muagent/models/openai_model.py @@ -0,0 +1,667 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for OpenAI models +The implementation of this _ModelWrapperMeta are borrowed from +https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py +""" + + +from abc import ABC +from typing import ( + Union, + Any, + List, + Sequence, + Dict, + Optional, + Generator, + Literal +) +from urllib.parse import urlparse +import os +import base64 +from loguru import logger +try: + import openai +except ImportError as e: + raise ImportError( + "Cannot find openai package, please install it by " + "`pip install openai`", + ) from e + +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types import CreateEmbeddingResponse +from .base_model import ModelWrapperBase +from ..schemas import Message + + + +OPENAI_MAX_LENGTH = { + "update": 20231212, + # gpt-4 + "gpt-4o-mini": 8192, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-32k": 32768, + "gpt-4-0613": 8192, + "gpt-4-32k-0613": 32768, + "gpt-4-0314": 8192, # legacy + "gpt-4-32k-0314": 32768, # legacy + # gpt-3.5 + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-3.5-turbo-0613": 4096, # legacy + "gpt-3.5-turbo-16k-0613": 16385, # deprecated on June 13th 2024 + "gpt-3.5-turbo-0301": 4096, # deprecated on June 13th 2024 + "text-davinci-003": 4096, # deprecated on Jan 4th 2024 + "text-davinci-002": 4096, # deprecated on Jan 4th 2024 + "code-davinci-002": 4096, # deprecated on Jan 4th 2024 + # gpt-3 legacy + "text-curie-001": 2049, + "text-babbage-001": 2049, + "text-ada-001": 2049, + "davinci": 2049, + "curie": 2049, + "babbage": 2049, + "ada": 2049, + # + "text-embedding-3-small": 8191, + "text-embedding-3-large": 8191, + 
"text-embedding-ada-002": 8191, +} + + +def get_openai_max_length(model_name: str) -> int: + """Get the max length of the OpenAi models.""" + try: + return OPENAI_MAX_LENGTH[model_name] + except KeyError as exc: + raise KeyError( + f"Model [{model_name}] not found in OPENAI_MAX_LENGTH. " + f"The last updated date is {OPENAI_MAX_LENGTH['update']}", + ) from exc + + + +def _to_openai_image_url(url: str) -> str: + """Convert an image url to openai format. If the given url is a local + file, it will be converted to base64 format. Otherwise, it will be + returned directly. + + Args: + url (`str`): + The local or public url of the image. + """ + # See https://platform.openai.com/docs/guides/vision for details of + # support image extensions. + support_image_extensions = ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ) + + parsed_url = urlparse(url) + + lower_url = url.lower() + + # Web url + if parsed_url.scheme != "": + if any(lower_url.endswith(_) for _ in support_image_extensions): + return url + + # Check if it is a local file + elif os.path.exists(url) and os.path.isfile(url): + if any(lower_url.endswith(_) for _ in support_image_extensions): + with open(url, "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode( + "utf-8", + ) + extension = parsed_url.path.lower().split(".")[-1] + mime_type = f"image/{extension}" + return f"data:{mime_type};base64,{base64_image}" + + raise TypeError(f"{url} should be end with {support_image_extensions}.") + + + +class OpenAIWrapperBase(ModelWrapperBase, ABC): + """The model wrapper for OpenAI API. + + Response: + - From https://platform.openai.com/docs/api-reference/chat/create + + ```json + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?", + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + ``` + """ + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + api_url: str = "https://api.openai.com/v1", + organization: str = None, + client_args: dict = None, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the openai client. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in OpenAI API. + api_key (`str`, default `None`): + The API key for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_API_KEY`. + organization (`str`, default `None`): + The organization ID for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_ORGANIZATION`. + client_args (`dict`, default `None`): + The extra keyword arguments to initialize the OpenAI client. + generate_args (`dict`, default `None`): + The extra keyword arguments used in openai api generation, + e.g. `temperature`, `seed`. 
+ """ + + if model_name is None: + model_name = config_name + logger.warning("model_name is not set, use config_name instead.") + + init_params = locals() + init_params.pop("self") + init_params["model_type"] = self.model_type + super().__init__(**init_params) + # super().__init__(config_name=config_name, model_name=model_name) + + self.generate_args = generate_args or {} + + try: + from zdatafront import ZDataFrontClient + from zdatafront.openai import SyncProxyHttpClient + VISIT_DOMAIN = os.environ.get("visit_domain") + VISIT_BIZ = os.environ.get("visit_biz") + VISIT_BIZ_LINE = os.environ.get("visit_biz_line") + aes_secret_key = os.environ.get("aes_secret_key") + zdatafront_client = ZDataFrontClient( + visit_domain=VISIT_DOMAIN, + visit_biz=VISIT_BIZ, + visit_biz_line=VISIT_BIZ_LINE, + aes_secret_key=aes_secret_key + ) + http_client = SyncProxyHttpClient(zdatafront_client=zdatafront_client, prefer_async=True) + except Exception as e: + logger.warning("There is no zdatafront, act as openai") + http_client = None + + if http_client: + self.client = openai.OpenAI( + api_key=api_key, + http_client=http_client, + organization=organization, + **(client_args or {}), + timeout=120, + ) + else: + self.client = openai.OpenAI( + api_key=api_key, + organization=organization, + **(client_args or {}), + ) + # Set the max length of OpenAI model + try: + self.max_length = get_openai_max_length(self.model_name) + except Exception as e: + logger.warning( + f"fail to get max_length for {self.model_name}: " f"{e}", + ) + self.max_length = None + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. Please try to use the " + f"model wrapper directly.", + ) + + +class OpenAIChatWrapper(OpenAIWrapperBase): + """The model wrapper for OpenAI's chat API.""" + + model_type: str = "openai_chat" + + substrings_in_vision_models_names = ["gpt-4-turbo", "vision", "gpt-4o"] + """The substrings in the model names of vision models.""" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + api_url: str = "https://api.openai.com/v1", + organization: str = None, + client_args: dict = None, + stream: bool = False, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the openai client. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in OpenAI API. + api_key (`str`, default `None`): + The API key for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_API_KEY`. + organization (`str`, default `None`): + The organization ID for OpenAI API. If not specified, it will + be read from the environment variable `OPENAI_ORGANIZATION`. + client_args (`dict`, default `None`): + The extra keyword arguments to initialize the OpenAI client. + stream (`bool`, default `False`): + Whether to enable stream mode. + generate_args (`dict`, default `None`): + The extra keyword arguments used in openai api generation, + e.g. `temperature`, `seed`. 
+ """ + + init_params = locals() + init_params.pop("self") + init_params["model_type"] = self.model_type + super().__init__(**init_params) + self.stream = stream + + def __call__( + self, + prompt: str = None, + messages: Sequence[dict] = [], + tools: Sequence[object] = [], + *, + tool_choice: Optional[Literal['auto', 'required']] = None, + parallel_tool_calls: Optional[bool] = None, + stream: bool = None, + stop: Optional[str] = '', + format_type: Literal['str', 'raw', 'dict'] = 'raw', + **kwargs: Any, + ) -> Generator[Union[ChatCompletionChunk, ChatCompletion], None, None]: + """Processes a list of messages to construct a payload for the OpenAI + API call. It then makes a request to the OpenAI API and returns the + response. This method also updates monitoring metrics based on the + API response. + + Each message in the 'messages' list can contain text content and + optionally an 'image_urls' key. If 'image_urls' is provided, + it is expected to be a list of strings representing URLs to images. + These URLs will be transformed to a suitable format for the OpenAI + API, which might involve converting local file paths to data URIs. + + Args: + messages (`list`): + A list of messages to process. + stream (`Optional[bool]`, defaults to `None`) + Whether to enable stream mode, which will override the + `stream` argument in the constructor if provided. + **kwargs (`Any`): + The keyword arguments to OpenAI chat completions API, + e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to + https://platform.openai.com/docs/api-reference/chat/create + for more detailed arguments. + + Returns: + `ModelResponse`: + The response text in text field, and the raw response in + raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved for + `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and check + the response generated by the model, which takes the response + as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. 
+ """ + + messages = [{"role": "user", "content": prompt}] if prompt else messages + + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: checking messages + if not isinstance(messages, list): + raise ValueError( + "OpenAI `messages` field expected type `list`, " + f"got `{type(messages)}` instead.", + ) + if not all("role" in Message and "content" in Message for Message in messages): + raise ValueError( + "Each message in the 'messages' list must contain a 'role' " + "and 'content' key for OpenAI API.", + ) + + # step3: forward to generate response + if stream is None: + stream = self.stream + + kwargs.update( + { + "model": self.model_name, + "messages": messages, + "stream": stream, + "tools": tools, + "tool_choice": tool_choice, + "parallel_tool_calls": parallel_tool_calls, + "stop": stop, + }, + ) + + if stream: + kwargs["stream_options"] = {"include_usage": True} + + response = self.client.chat.completions.create(**kwargs) + + if format_type == "str": + content = "" + if stream: + for chunk in response: + content += chunk.choices[0].delta.content or '' + yield content + else: + yield response.choices[0].message.content + else: + if stream: + for chunk in response: + yield chunk + else: + yield response + + @staticmethod + def _format_Message_with_url( + message: Message, + model_name: str, + ) -> Dict: + """Format a message with image urls into openai chat format. + This format method is used for gpt-4o, gpt-4-turbo, gpt-4-vision and + other vision models. + """ + # Check if the model is a vision model + if not any( + _ in model_name + for _ in OpenAIChatWrapper.substrings_in_vision_models_names + ): + logger.warning( + f"The model {model_name} is not a vision model. " + f"Skip the url in the message.", + ) + return { + "role": message.role_type, + "name": message.role_name, + "content": message.content, + } + + # Put all urls into a list + urls = message.image_urls if isinstance(message.image_urls, list) else [message.image_urls] + + # Check if the url refers to an image + checked_urls = [] + for url in urls: + try: + checked_urls.append(_to_openai_image_url(url)) + except TypeError: + logger.warning( + f"The url {url} is not a valid image url for " + f"OpenAI Chat API, skipped.", + ) + + if len(checked_urls) == 0: + # If no valid image url is provided, return the normal message dict + return { + "role": message.role_type, + "name": message.role_name, + "content": message.content, + } + else: + # otherwise, use the vision format message + returned_Message = { + "role": message.role_type, + "name": message.role_name, + "content": [ + { + "type": "text", + "text": message.content, + }, + ], + } + + image_dicts = [ + { + "type": "image_url", + "image_url": { + "url": _, + }, + } + for _ in checked_urls + ] + + returned_Message["content"].extend(image_dicts) + + return returned_Message + + @staticmethod + def static_format( + *args: Union[Message, Sequence[Message]], + model_name: str, + ) -> List[dict]: + """A static version of the format method, which can be used without + initializing the OpenAIChatWrapper object. + + Args: + args (`Union[Message, Sequence[Message]]`): + The input arguments to be formatted, where each argument + should be a `Message` object, or a list of `Message` objects. + In distribution, placeholder is also allowed. + model_name (`str`): + The name of the model to use in OpenAI API. + + Returns: + `List[dict]`: + The formatted messages in the format that OpenAI Chat API + required. 
+ """ + messages = [] + for arg in args: + if arg is None: + continue + if isinstance(arg, Message): + if arg.image_urls is not None and arg.image_urls: + # Format the message according to the model type + # (vision/non-vision) + formatted_Message = OpenAIChatWrapper._format_Message_with_url( + arg, + model_name, + ) + messages.append(formatted_Message) + else: + messages.append( + { + "role": arg.role_type, + "name": arg.role_name, + "content": arg.content, + }, + ) + + elif isinstance(arg, list): + messages.extend( + OpenAIChatWrapper.static_format( + *arg, + model_name=model_name, + ), + ) + else: + raise TypeError( + f"The input should be a Message object or a list " + f"of Message objects, got {type(arg)}.", + ) + + return messages + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """Format the input string and dictionary into the format that + OpenAI Chat API required. If you're using a OpenAI-compatible model + without a prefix "gpt-" in its name, the format method will + automatically format the input messages into the required format. + + Args: + args (`Union[Message, Sequence[Message]]`): + The input arguments to be formatted, where each argument + should be a `Message` object, or a list of `Message` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages in the format that OpenAI Chat API + required. + """ + + # Format messages according to the model name + if self.model_name.startswith("gpt-"): + return OpenAIChatWrapper.static_format( + *args, + model_name=self.model_name, + ) + else: + # The OpenAI library maybe re-used to support other models + return ModelWrapperBase.format_for_common_chat_models(*args) + + +class OpenAIEmbeddingWrapper(OpenAIWrapperBase): + """The model wrapper for OpenAI embedding API. + + Response: + - Refer to + https://platform.openai.com/docs/api-reference/embeddings/create + + ```json + { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + ``` + """ + + model_type: str = "openai_embedding" + + def __call__( + self, + texts: Union[list[str], str], + dimension=768, + **kwargs: Any, + ) -> CreateEmbeddingResponse: + """Embed the messages with OpenAI embedding API. + + Args: + texts (`list[str]` or `str`): + The messages used to embed. + **kwargs (`Any`): + The keyword arguments to OpenAI embedding API, + e.g. `encoding_format`, `user`. Please refer to + https://platform.openai.com/docs/api-reference/embeddings + for more detailed arguments. + + Returns: + `ModelResponse`: + A list of embeddings in embedding field and the + raw response in raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved for + `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and check + the response generated by the model, which takes the response + as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. 
+ """ + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: forward to generate response + response = self.client.embeddings.create( + input=texts, + model=self.model_name, + **kwargs, + ) + # step5: return response + response_json = response.model_dump() + return response_json + + def embed_query(self, text: str) -> List[float]: + response = self([text]) + output = response["data"] + return output[0]["embedding"] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + response = self(texts) + output = response["data"] + return [emb["embedding"] for emb in output] + \ No newline at end of file diff --git a/muagent/models/qwen_model.py b/muagent/models/qwen_model.py new file mode 100644 index 0000000..2844bf1 --- /dev/null +++ b/muagent/models/qwen_model.py @@ -0,0 +1,461 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for DashScope models""" +import os +from abc import ABC +from http import HTTPStatus +from typing import ( + Any, + Union, + List, + Sequence, + Optional, + Generator, + Literal +) + +import openai +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from loguru import logger + +try: + import dashscope + + dashscope_version = dashscope.version.__version__ + if dashscope_version < "1.19.0": + logger.warning( + f"You are using 'dashscope' version {dashscope_version}, " + "which is below the recommended version 1.19.0. " + "Please consider upgrading to maintain compatibility.", + ) + from dashscope.api_entities.dashscope_response import GenerationResponse +except ImportError: + dashscope = None + GenerationResponse = None + + +from ..schemas import Message +from .base_model import ModelWrapperBase +from ..utils.common_utils import _convert_to_str + + + +class QwenWrapperBase(ModelWrapperBase, ABC): + """The model wrapper for DashScope API.""" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the DashScope wrapper. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in DashScope API. + api_key (`str`, default `None`): + The API key for DashScope API. + generate_args (`dict`, default `None`): + The extra keyword arguments used in DashScope api generation, + e.g. `temperature`, `seed`. + """ + if model_name is None: + model_name = config_name + logger.warning("model_name is not set, use config_name instead.") + + super().__init__(config_name=config_name, model_name=model_name) + + self.generate_args = generate_args or {} + + self.api_key = api_key + self.max_length = None + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. 
Please try to use the "
+            f"model wrapper directly.",
+        )
+
+
+class QwenChatWrapper(QwenWrapperBase):
+    """The model wrapper for DashScope's chat API, refer to
+    https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+
+    Response:
+        - Refer to
+        https://help.aliyun.com/zh/dashscope/developer-reference/quick-start?spm=a2c4g.11186623.0.0.7e346eb5RvirBw
+
+        ```json
+        {
+            "status_code": 200,
+            "request_id": "a75a1b22-e512-957d-891b-37db858ae738",
+            "code": "",
+            "message": "",
+            "output": {
+                "text": null,
+                "finish_reason": null,
+                "choices": [
+                    {
+                        "finish_reason": "stop",
+                        "message": {
+                            "role": "assistant",
+                            "content": "xxx"
+                        }
+                    }
+                ]
+            },
+            "usage": {
+                "input_tokens": 25,
+                "output_tokens": 77,
+                "total_tokens": 102
+            }
+        }
+        ```
+    """
+
+    model_type: str = "qwen_chat"
+
+    def __init__(
+        self,
+        config_name: str,
+        model_name: str = None,
+        api_key: str = None,
+        api_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
+        stream: bool = False,
+        generate_args: dict = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the DashScope wrapper.
+
+        Args:
+            config_name (`str`):
+                The name of the model config.
+            model_name (`str`, default `None`):
+                The name of the model to use in DashScope API.
+            api_key (`str`, default `None`):
+                The API key for DashScope API.
+            api_url (`str`, default the DashScope compatible-mode endpoint):
+                The base url of the OpenAI-compatible DashScope API.
+            stream (`bool`, default `False`):
+                Whether to enable stream mode.
+            generate_args (`dict`, default `None`):
+                The extra keyword arguments used in DashScope api generation,
+                e.g. `temperature`, `seed`.
+        """
+
+        super().__init__(
+            config_name=config_name,
+            model_name=model_name,
+            api_key=api_key,
+            generate_args=generate_args,
+            **kwargs,
+        )
+        self.api_url = api_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"
+        self.stream = stream
+        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url)
+
+    def __call__(
+        self,
+        prompt: str = None,
+        messages: Sequence[dict] = [],
+        tools: Sequence[object] = [],
+        *,
+        tool_choice: Optional[Literal['auto', 'required']] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        stream: Optional[bool] = None,
+        stop: Optional[str] = '',
+        format_type: Literal['str', 'raw', 'dict'] = 'raw',
+        **kwargs
+    ) -> Generator[Union[ChatCompletionChunk, ChatCompletion], None, None]:
+        """Invoke the DashScope chat API (OpenAI-compatible mode) by sending
+        a list of messages."""
+
+        messages = [{"role": "user", "content": prompt}] if prompt else messages
+        # Checking messages
+        if not isinstance(messages, list):
+            raise ValueError(
+                f"Qwen `messages` field expected type `list`, "
+                f"got `{type(messages)}` instead.",
+            )
+
+        if not all("role" in msg and "content" in msg for msg in messages):
+            raise ValueError(
+                "Each message in the 'messages' list must contain a 'role' "
+                "and 'content' key for the Qwen API.",
+            )
+
+        stream = self.stream if stream is None else stream
+        model_name = self.model_name
+
+        # step1: prepare keyword arguments
+        kwargs = {**self.generate_args, **kwargs}
+        kwargs.update(
+            {
+                "model": model_name,
+                "messages": messages,
+                "stream": stream,
+            },
+        )
+        # only pass the optional fields when they are actually set
+        if stop:
+            kwargs["stop"] = stop
+        if tools:
+            kwargs["tools"] = tools
+            if tool_choice is not None:
+                kwargs["tool_choice"] = tool_choice
+            if parallel_tool_calls is not None:
+                kwargs["parallel_tool_calls"] = parallel_tool_calls
+
+        # step2: forward to generate response
+        response = self.client.chat.completions.create(**kwargs)
+        if format_type == "str":
+            content = ""
+            if stream:
+                for chunk in response:
+                    content += chunk.choices[0].delta.content or ''
+                    yield content
+            else:
+                yield response.choices[0].message.content
+        else:
+            if stream:
+                for chunk in
response: + yield chunk + else: + yield response + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """A common format strategy for chat models, which will format the + input messages into a user message. + + Note this strategy maybe not suitable for all scenarios, + and developers are encouraged to implement their own prompt + engineering strategies. + + The following is an example: + + .. code-block:: python + + prompt1 = model.format( + Message("system", "You're a helpful assistant", role="system"), + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + prompt2 = model.format( + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + The prompt will be as follows: + + .. code-block:: python + + # prompt1 + [ + { + "role": "system", + "content": "You're a helpful assistant" + }, + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + # prompt2 + [ + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages. + """ + + return ModelWrapperBase.format_for_common_chat_models(*args) + + + def format_prompt(self, *args: Union[Message, Sequence[Message]]) -> str: + """Forward the input to the model. + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `str`: + The formatted string prompt. + """ + input_msgs: List[Message] = [] + for _ in args: + if _ is None: + continue + if isinstance(_, Message): + input_msgs.append(_) + elif isinstance(_, list) and all(isinstance(__, Message) for __ in _): + input_msgs.extend(_) + else: + raise TypeError( + f"The input should be a Msg object or a list " + f"of Msg objects, got {type(_)}.", + ) + + sys_prompt = None + dialogue = [] + for i, unit in enumerate(input_msgs): + if i == 0 and unit.role_type == "system": + # system prompt + sys_prompt = unit.content + else: + # Merge all messages into a conversation history prompt + dialogue.append( + f"{unit.role_name}: {unit.content}", + ) + + dialogue_history = "\n".join(dialogue) + + if sys_prompt is None: + prompt_template = "## Conversation History\n{dialogue_history}" + else: + prompt_template = ( + "{system_prompt}\n" + "\n" + "## Conversation History\n" + "{dialogue_history}" + ) + + return prompt_template.format( + system_prompt=sys_prompt, + dialogue_history=dialogue_history, + ) + +class QwenTextEmbeddingWrapper(QwenWrapperBase): + """The model wrapper for DashScope Text Embedding API. + + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details?spm=a2c4g.11186623.0.i3 + + ```json + { + "status_code": 200, // 200 indicate success otherwise failed. + "request_id": "fd564688-43f7-9595-b986", // The request id. + "code": "", // If failed, the error code. + "message": "", // If failed, the error message. 
+ "output": { + "embeddings": [ // embeddings + { + "embedding": [ // one embedding output + -3.8450357913970947, ..., + ], + "text_index": 0 // the input index. + } + ] + }, + "usage": { + "total_tokens": 3 // the request tokens. + } + } + ``` + """ + + model_type: str = "qwen_text_embedding" + + def __call__( + self, + texts: Union[list[str], str], + dimension: Literal[512, 768, 1024, 1536] = 768, + **kwargs: Any, + ): + """Embed the messages with DashScope Text Embedding API. + + Args: + texts (`list[str]` or `str`): + The messages used to embed. + **kwargs (`Any`): + The keyword arguments to DashScope Text Embedding API, + e.g. `text_type`. Please refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details-15 + for more detailed arguments. + + Returns: + `ModelResponse`: + A list of embeddings in embedding field and the raw + response in raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved + for `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and + check the response generated by the model, which takes the + response as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + # client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url) + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: forward to generate response + response = dashscope.TextEmbedding.call( + input=texts, + model=self.model_name, + api_key=self.api_key, + dimension=dimension, + **kwargs, + ) + + if response.status_code != HTTPStatus.OK: + error_msg = ( + f" Request id: {response.request_id}," + f" Status code: {response.status_code}," + f" error code: {response.code}," + f" error message: {response.message}." + ) + raise RuntimeError(error_msg) + + # step5: return response + return response + + def embed_query(self, text: str) -> List[float]: + response = self([text]) + output = response["output"] + embeddings = output["embeddings"] + return embeddings[0]["embedding"] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + response = self(texts) + output = response["output"] + embeddings = output["embeddings"] + return [emb["embedding"] for emb in embeddings] + \ No newline at end of file diff --git a/muagent/models/yi_model.py b/muagent/models/yi_model.py new file mode 100644 index 0000000..4140baf --- /dev/null +++ b/muagent/models/yi_model.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +"""Model wrapper for Yi models +The implementation of this _ModelWrapperMeta are borrowed from +https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py +""" + + +import json +from typing import ( + List, + Union, + Sequence, + Optional, + Generator, + Literal +) + +import openai +from openai.types.chat import ChatCompletion, ChatCompletionChunk + +from .base_model import ModelWrapperBase +from ..schemas import Message + + +class YiChatWrapper(ModelWrapperBase): + """The model wrapper for Yi Chat API. 
+
+    Response:
+        - From https://platform.lingyiwanwu.com/docs
+
+        ```json
+        {
+            "id": "cmpl-ea89ae83",
+            "object": "chat.completion",
+            "created": 5785971,
+            "model": "yi-large-rag",
+            "usage": {
+                "completion_tokens": 113,
+                "prompt_tokens": 896,
+                "total_tokens": 1009
+            },
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": "Today in Los Angeles, the weather ...",
+                    },
+                    "finish_reason": "stop"
+                }
+            ]
+        }
+        ```
+    """
+
+    model_type: str = "yi_chat"
+
+    def __init__(
+        self,
+        config_name: str,
+        model_name: str,
+        api_key: str,
+        api_url: str = "https://api.lingyiwanwu.com/v1",
+        max_tokens: Optional[int] = None,
+        top_p: float = 0.9,
+        temperature: float = 0.3,
+        stream: bool = False,
+        **kwargs,
+    ) -> None:
+        """Initialize the Yi chat model wrapper.
+
+        Args:
+            config_name (`str`):
+                The name of the configuration to use.
+            model_name (`str`):
+                The name of the model to use, e.g. yi-large, yi-medium, etc.
+            api_key (`str`):
+                The API key for the Yi API.
+            api_url (`str`, defaults to `"https://api.lingyiwanwu.com/v1"`):
+                The base url of the Yi API.
+            max_tokens (`Optional[int]`, defaults to `None`):
+                The maximum number of tokens to generate, defaults to `None`.
+            top_p (`float`, defaults to `0.9`):
+                The nucleus sampling parameter in the range [0, 1].
+            temperature (`float`, defaults to `0.3`):
+                The temperature parameter in the range [0, 2].
+            stream (`bool`, defaults to `False`):
+                Whether to stream the response or not.
+        """
+
+        init_params = locals()
+        init_params.pop("self")
+        init_params["model_type"] = self.model_type
+        super().__init__(**init_params)
+
+        if top_p > 1 or top_p < 0:
+            raise ValueError(
+                f"The `top_p` parameter must be in the range [0, 1], but got "
+                f"{top_p} instead.",
+            )
+
+        if temperature < 0 or temperature > 2:
+            raise ValueError(
+                f"The `temperature` parameter must be in the range [0, 2], "
+                f"but got {temperature} instead.",
+            )
+        self.api_url = api_url or "https://api.lingyiwanwu.com/v1"
+        self.client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url)
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.temperature = temperature
+        self.stream = stream
+
+    def __call__(
+        self,
+        prompt: str = None,
+        messages: Sequence[dict] = [],
+        tools: Sequence[object] = [],
+        *,
+        tool_choice: Literal['auto', 'required'] = 'auto',
+        parallel_tool_calls: Optional[bool] = None,
+        stream: Optional[bool] = None,
+        stop: Optional[str] = '',
+        format_type: Literal['str', 'raw', 'dict'] = 'raw',
+        **kwargs
+    ) -> Generator[Union[ChatCompletionChunk, ChatCompletion, str], None, None]:
+        """Invoke the Yi Chat API by sending a list of messages."""
+
+        messages = [{"role": "user", "content": prompt}] if prompt else messages
+        # Checking messages
+        if not isinstance(messages, list):
+            raise ValueError(
+                f"Yi `messages` field expected type `list`, "
+                f"got `{type(messages)}` instead.",
+            )
+
+        if not all("role" in msg and "content" in msg for msg in messages):
+            raise ValueError(
+                "Each message in the 'messages' list must contain a 'role' "
+                "and 'content' key for Yi API.",
+            )
+
+        stream = self.stream if stream is None else stream
+        # switch to the function-call variant when tools are given
+        model_name = "yi-large-fc" if tools else self.model_name
+
+        kwargs.update(
+            {
+                "model": model_name,
+                "messages": messages,
+                "stream": stream,
+                "temperature": self.temperature,
+                "max_tokens": self.max_tokens,
+                "top_p": self.top_p,
+            },
+        )
+        # only pass tool-related fields and `stop` when they are actually set
+        if tools:
+            kwargs["tools"] = tools
+            kwargs["tool_choice"] = tool_choice
+            if parallel_tool_calls is not None:
+                kwargs["parallel_tool_calls"] = parallel_tool_calls
+        if stop:
+            kwargs["stop"] = [stop]
+
+        response = self.client.chat.completions.create(**kwargs)
+
+        if format_type == "str":
+            content = ""
+            if
stream: + for chunk in response: + content += chunk.choices[0].delta.content or '' + yield content + else: + yield response.choices[0].message.content + else: + if stream: + for chunk in response: + yield chunk + else: + yield response + + def function_call( + self, + messages: Optional[Sequence[dict]] = None, + tools: Sequence[object] = [], + *, + prompt: Optional[str] = None, + tool_choice: Literal['auto', 'required'] = 'auto', + parallel_tool_calls: Optional[bool] = None, + stream: Optional[bool] = False, + ) -> ChatCompletion: + """Call a function to process messages with optional tools. + + Args: + messages (Optional[Sequence[dict]], optional): A sequence of messages for context. + tools (Sequence[object], optional): Tools available for use. + prompt (Optional[str], optional): An optional prompt. + tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools. + parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls. + stream (Optional[bool], optional): If true, streams the output instead of returning it all at once. + + Returns: + Union[ChatCompletion, Mapping]: The result of the function call processed by the model. + """ + kwargs = locals() + kwargs.pop("self") + kwargs.pop("__class__") + return super().function_call(**kwargs) + + def function_call_stream( + self, + messages: Optional[Sequence[dict]] = None, + tools: Sequence[object] = [], + *, + prompt: Optional[str] = None, + tool_choice: Literal['auto', 'required'] = 'auto', + parallel_tool_calls: Optional[bool] = None, + stream: Optional[bool] = True, + ) -> Generator[ChatCompletionChunk, None, None]: + """Stream function call outputs. + + Args: + messages (Optional[Sequence[dict]], optional): A sequence of messages for context. + tools (Sequence[object], optional): Tools available for use. + prompt (Optional[str], optional): An optional prompt. + tool_choice (Optional[Literal['auto', 'required']], optional): How to select tools. + parallel_tool_calls (Optional[bool], optional): If true, allows parallel tool calls. + stream (Optional[bool], optional): If true, streams the output instead of returning it all at once. + + Yields: + Generator[Union[ChatCompletionChunk, Mapping], None, None]: A generator yielding parts of the function output. + """ + kwargs = locals() + kwargs.pop("self") + kwargs.pop("__class__") + for i in super().function_call_stream(**kwargs): yield i + + def format( + self, + *args: Union[Message, Sequence[Message]], + ) -> List[dict]: + """Format the messages into the required format of Yi Chat API. + + Note this strategy maybe not suitable for all scenarios, + and developers are encouraged to implement their own prompt + engineering strategies. + + The following is an example: + + .. code-block:: python + + prompt1 = model.format( + Message("system", "You're a helpful assistant", role="system"), + Message("Bob", "Hi, how can I help you?", role="assistant"), + Message("user", "What's the date today?", role="user") + ) + + The prompt will be as follows: + + .. code-block:: python + + # prompt1 + [ + { + "role": "system", + "content": "You're a helpful assistant" + }, + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + Args: + args (`Union[Message, Sequence[Message]]`): + The input arguments to be formatted, where each argument + should be a `Message` object, or a list of `Message` objects. + In distribution, placeholder is also allowed. 
+ + Returns: + `List[dict]`: + The formatted messages. + """ + + # TODO: Support Vision model + if self.model_name == "yi-vision": + raise NotImplementedError( + "Yi Vision model is not supported in the current version, " + "please format the messages manually.", + ) + + return ModelWrapperBase.format_for_common_chat_models(*args) \ No newline at end of file diff --git a/muagent/orm/__init__.py b/muagent/orm/__init__.py deleted file mode 100644 index 2a2c21b..0000000 --- a/muagent/orm/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .db import _engine, Base -from loguru import logger - -__all__ = [ - -] - -def create_tables(): - Base.metadata.create_all(bind=_engine) - -def reset_tables(): - Base.metadata.drop_all(bind=_engine) - create_tables() - - -def check_tables_exist(table_name) -> bool: - table_exist = _engine.dialect.has_table(_engine.connect(), table_name, schema=None) - return table_exist - -def table_init(): - if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")) or \ - (not check_tables_exist ("code_base")): - create_tables() diff --git a/muagent/project_manager.py b/muagent/project_manager.py new file mode 100644 index 0000000..5f0250d --- /dev/null +++ b/muagent/project_manager.py @@ -0,0 +1,70 @@ +from typing import ( + Dict, + Optional, + Union +) +import os, sys, json, random +from loguru import logger + + +from .schemas.models import ModelConfig, LLMConfig +from .schemas import ProjectConfig, PromptConfig, AgentConfig + + +def get_project_config_from_env( + agent_configs: Optional[Dict[str, AgentConfig]] = None, + model_configs: Optional[Dict[str, Union[ModelConfig, LLMConfig]]] = None, + prompt_configs: Optional[Dict[str, PromptConfig]] = PromptConfig(), +) -> ProjectConfig: + """""" + init_dict = { + "model_configs": [model_configs, ModelConfig], + "agent_configs": [agent_configs, AgentConfig], + "prompt_configs": [prompt_configs, PromptConfig], + } + project_configs = { + "model_configs": None, + "agent_configs": None, + "prompt_configs": None, + } + for k, (v, _type) in init_dict.items(): + if v: + pass + elif k.upper() in os.environ: + v = json.loads(os.environ[k.upper()]) + vc = {} + for kk, vv in v.items(): + try: + vc[kk] = _type(**vv) + except: + vc[kk] = LLMConfig(**vv) + v = vc + if v: + chat_list = [_type for _type in v.keys() if "chat" in _type] + embedding_list = [_type for _type in v.keys() if "embedding" in _type] + if chat_list: + v["default_chat"] = v[random.choice(chat_list)] + model_type = random.choice(chat_list) + default_model_config = v[model_type] + os.environ["DEFAULT_MODEL_TYPE"] = model_type + os.environ["DEFAULT_MODEL_NAME"] = default_model_config.model_name + os.environ["DEFAULT_API_KEY"] = default_model_config.api_key or "" + os.environ["DEFAULT_API_URL"] = default_model_config.api_url or "" + if embedding_list: + v["default_embed"] = v[random.choice(embedding_list)] + model_type = random.choice(chat_list) + default_model_config = v[model_type] + os.environ["DEFAULT_EMBED_MODEL_TYPE"] = model_type + os.environ["DEFAULT_EMBED_MODEL_NAME"] = default_model_config.model_name + os.environ["DEFAULT_EMBED_API_KEY"] = default_model_config.api_key or "" + os.environ["DEFAULT_EMBED_API_URL"] = default_model_config.api_url or "" + project_configs[k] = v + else: + logger.warning( + f"Cant't init any {k} in this env." + ) + else: + logger.warning( + f"Cant't init any {k} in this env." 
+            )
+    return ProjectConfig(**project_configs)
\ No newline at end of file
diff --git a/muagent/prompt_manager/__init__.py b/muagent/prompt_manager/__init__.py
new file mode 100644
index 0000000..00ea8cf
--- /dev/null
+++ b/muagent/prompt_manager/__init__.py
@@ -0,0 +1,8 @@
+from .base_prompt_manager import BasePromptManager
+from .common_prompt_manager import CommonPromptManager
+
+
+__all__ = [
+    "BasePromptManager",
+    "CommonPromptManager"
+]
\ No newline at end of file
diff --git a/muagent/prompt_manager/base.py b/muagent/prompt_manager/base.py
new file mode 100644
index 0000000..479f1eb
--- /dev/null
+++ b/muagent/prompt_manager/base.py
@@ -0,0 +1,32 @@
+from .language.en import *
+from .language.zh import *
+
+
+TITLE_CONFIGS_LANGUAGE = {
+    "en": EN_TITLE_CONFIGS,
+    "zh": ZH_TITLE_CONFIGS,
+}
+
+TITLE_EDGES_LANGUAGE = {
+    "en": EN_TITLE_EDGES,
+    "zh": ZH_TITLE_EDGES,
+}
+
+TITLE_FORMAT_LANGUAGE = {
+    "en": EN_TITLE_FORMAT,
+    "zh": ZH_TITLE_FORMAT
+}
+
+TITLE_LANGUAGE = {
+    "en": EN_TITLES,
+    "zh": ZH_TITLES
+}
+
+ZERO_TITLES_LANGUAGE = {
+    "en": EN_ZERO_TITLES,
+    "zh": ZH_ZERO_TITLES,
+}
+COMMON_TEXT_LANGUAGE = {
+    "en": EN_COMMON_TEXT,
+    "zh": ZH_COMMON_TEXT
+}
\ No newline at end of file
diff --git a/muagent/prompt_manager/base_prompt_manager.py b/muagent/prompt_manager/base_prompt_manager.py
new file mode 100644
index 0000000..0c0c0c0
--- /dev/null
+++ b/muagent/prompt_manager/base_prompt_manager.py
@@ -0,0 +1,506 @@
+from abc import ABCMeta, abstractmethod
+from typing import (
+    Any,
+    Union,
+    Optional,
+    Type,
+    Literal,
+    Dict,
+    List,
+    Tuple,
+    Sequence,
+    Mapping
+)
+from pydantic import BaseModel
+from loguru import logger
+import os
+import uuid
+import copy
+
+from .base import *
+from .util import edges_to_graph_with_cycle_detection
+from ..sandbox import NBClientBox
+from ..tools import get_tool
+from ..schemas import Memory, Message, PromptConfig
+from ..schemas.common import ActionStatus, LogVerboseEnum
+
+from muagent.base_configs.env_config import KB_ROOT_PATH
+
+
+class _PromptManagerWrapperMeta(ABCMeta):
+    """Metaclass that keeps a registry of prompt manager subclasses,
+    keyed both by class name and by their `pm_type` attribute."""
+
+    def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
+        return super().__new__(mcs, name, bases, attrs)
+
+    def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
+        if not hasattr(cls, "_registry"):
+            cls._registry = {}
+            cls._type_registry = {}
+        else:
+            cls._registry[name] = cls
+            if hasattr(cls, "pm_type"):
+                cls._type_registry[cls.pm_type] = cls
+        super().__init__(name, bases, attrs)
+
+
+class BasePromptManager(metaclass=_PromptManagerWrapperMeta):
+
+    pm_type: str = "BasePromptManager"
+    """The type of prompt manager."""
+
+    def __init__(
+        self,
+        system_prompt: str = "you are a helpful assistant!\n",
+        input_template: Union[str, BaseModel] = "",
+        output_template: Union[str, BaseModel] = "",
+        prompt: Optional[str] = "",
+        language: Literal["en", "zh"] = "en",
+        *,
+        monitored_agents=[],
+        monitored_fields=[],
+        log_verbose: str = "0",
+        workdir_path: str = KB_ROOT_PATH,
+        **kwargs
+    ):
+        #
+        self.system_prompt = system_prompt
+        self.input_template = input_template
+        self.output_template = output_template
+        self.prompt = prompt
+        self.language = language
+        # deprecated
+        self.monitored_agents = monitored_agents
+        self.monitored_fields = monitored_fields
+        #
+        self.extra_registry_titles: Dict = {}
+        self.extra_register_edges: Sequence = []
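+        # extension hooks: subclasses (e.g. CommonPromptManager) can fill
+        # these to inject extra titles, edges and formats before
+        # `register_graph` runs
+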
self.new_dfsindex_to_str_format: Dict = {} + """use {title name} {description/function_value}""" + + # + self.codebox = NBClientBox(do_code_exe=True) # Initialize code execution box + self.workdir_path = workdir_path # Set the working directory path + self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose # Configure logging verbosity + + @classmethod + def from_config(self, prompt_config: PromptConfig, **kwargs) -> 'BasePromptManager': + """Get the prompt manager from PromptConfig""" + init_kwargs = {**kwargs, **prompt_config.dict()} + return self.get_wrapper(prompt_config.prompt_manager_type)(**init_kwargs) + + @classmethod + def get_wrapper(cls, prompt_manager_type: str) -> Type['BasePromptManager']: + """Get the specific PromptManager wrapper""" + if prompt_manager_type in cls._type_registry: + return cls._type_registry[prompt_manager_type] # type: ignore[return-value] + elif prompt_manager_type in cls._registry: + return cls._registry[prompt_manager_type] # type: ignore[return-value] + else: + raise KeyError( + f"Unsupported prompt_manager_type [{prompt_manager_type}]" + ) + + def register_graph( + self, + title_configs: Mapping[str, Mapping] = {}, + title_edges: Sequence[Sequence[str]] = {}, + title_format: Mapping[int, str] = {}, + titles: Mapping[str, Sequence[str]] = {}, + zero_titles: Mapping = {}, + common_texts: Mapping[str, str] = {} + ): + """transform title and edge into title graph to execute""" + # custom define + self.register_env( + title_configs, title_edges, title_format, titles, + zero_titles=zero_titles, + common_texts=common_texts + ) + self.register_prompt() + + # prepare title graph + start_nodes, self.title_graph = edges_to_graph_with_cycle_detection(self._registry_edges) + for title in start_nodes: + if title not in self._title_prefix + self._title_suffix: + self._title_middle.append(title) + + if LogVerboseEnum.le(LogVerboseEnum.Log3Level, os.environ.get("log_verbose", "0")): + logger.debug(f"{self._registry_titles}, {self._registry_edges}, {self.title_graph}") + + def register_env( + self, + title_configs: Mapping[str, Mapping] = {}, + title_edges: Sequence[Sequence[str]] = {}, + title_format: Mapping[int, str] = {}, + titles: Mapping[str, Sequence[str]] = {}, + *, + zero_titles: Mapping = {}, + common_texts: Mapping[str, str] = {} + ): + self._registry_titles = copy.deepcopy(title_configs) + self._registry_titles.update(self.extra_registry_titles) + self._registry_edges = copy.deepcopy(title_edges) + self._registry_edges.extend(self.extra_register_edges) + + self._dfsindex_to_str_format = copy.deepcopy(title_format) + self._dfsindex_to_str_format.update(self.new_dfsindex_to_str_format) + + self._title_prefix = titles.get("title_prefix", []) + self._title_suffix = titles.get("title_suffix", []) + self._title_middle = titles.get("title_middle", []) + + self._zero_titles = copy.deepcopy(zero_titles) # or ZERO_TITLES_LANGUAGE.get(self.language) + self._common_texts = copy.deepcopy(common_texts) or COMMON_TEXT_LANGUAGE.get(self.language) + + @abstractmethod + def register_prompt(self, ): + """register input/output/prompt into titles and edges""" + raise NotImplementedError( + f"Prompt Manager Wrapper [{type(self).__name__}]" + f" is missing the required `register_prompt`" + f" method.", + ) + + def pre_print(self, **kwargs) -> str: + kwargs.update({"is_pre_print": True}) + prompt = self.generate_prompt(**kwargs) + return prompt + + def generate_prompt(self, **kwargs) -> str: + '''force to print all prompt format whatever it has value''' + if 
self.prompt: + return self.prompt.format(**self.handler_prompt_values(**kwargs)) + + is_pre_print = kwargs.get("is_pre_print", False) + # update title's description and function_value + title_values = {} + for title, title_config in self._registry_titles.items(): + if hasattr(self, title_config["function"]): + + handler = getattr(self, title_config["function"]) + function_value = handler( + prompt=title_config.get("prompt", ""), title_key=title, **kwargs + ) if handler else None + else: + function_value = title_config["description"] + + title_values[title] = { + "description": title_config["description"], + "function_value": function_value, + "display_type": title_config["display_type"], + "str_template": title_config.get("str_template", ""), + "prompt": title_config.get("prompt", ""), + } + + # transform title values into 'markdown' prompt by title graph + prompt_values: List[str] = [] + prompt_values = self._process_title_values( + title_values, + title_type="description", + prompt_values=prompt_values, + is_pre_print=is_pre_print + ) + + transition_text = self._common_texts["transition_text"] + prompt_values.append(self._dfsindex_to_str_format[0].format(transition_text, "")) + + prompt_values = self._process_title_values( + title_values, + title_type="value", + prompt_values=prompt_values, + is_pre_print=is_pre_print + ) + + # logger.info(prompt_values) + reponse_text = self._common_texts["reponse_text"] + if not any("RESPONSE OUTPUT" in i for i in prompt_values): + prompt_values.append(reponse_text) + elif not any(["RESPONSE OUTPUT\n" in i for i in prompt_values]): + prompt_values.append(self._dfsindex_to_str_format[0].format("RESPONSE OUTPUT", "")) + # return prompt except '\n' in end + prompt_values = [pv.rstrip('\n') for pv in prompt_values] + return '\n\n'.join(prompt_values) + + def _process_title_values( + self, + title_values: Mapping[str, Mapping[str, Any]], + title_type: Literal["description", "value"], + prompt_values: Sequence[str] = [], + is_pre_print=False + ): + '''process title values to prompt''' + + def append_prompt_dfs(titles: Sequence[str], prompt_values: Sequence=[], dfs_index=0): + '''''' + if titles == [] or titles is None: return prompt_values + for title in titles: + title_value = title_values.get(title) + ctitles = self.title_graph.get(title, []) + ctitle_values = [ + ctitle + for ctitle in ctitles + if title_values.get(ctitle, {}).get('function_value') + ] + + str_template = title_value.get( + "str_template", self._dfsindex_to_str_format[dfs_index] + ) or self._dfsindex_to_str_format[dfs_index] + description = title_value["description"] + function_value = title_value["function_value"] + display_type = title_value["display_type"] + prompt = title_value["prompt"] + + # logger.info( + # f"title={title}, description={description}, function_value={function_value} \n" + # f"display_type={display_type}, str_template={str_template} \n" + # f"ctitles= {ctitles}, ctitle_values={ctitle_values}" + # ) + + # todo display_type==only_value + if title_type=="description": + if display_type == "title": + prompt_values.append(str_template.format(title, description or function_value)) + elif display_type=="description" and function_value: + prompt_values.append(str_template.format(title, function_value or description or prompt)) + elif display_type == "value" and (function_value or len(ctitle_values)>0): + prompt_values.append(str_template.format(title, description or function_value)) + elif display_type == "values" and len(ctitle_values)>0: + 
prompt_values.append(str_template.format(title, description or function_value)) + elif display_type == "must_value" and (description or function_value or len(ctitle_values)>0): + prompt_values.append(str_template.format(title, description or function_value)) + elif is_pre_print: + prompt_values.append(str_template.format(title, description or function_value)) + elif title_type=="value": + if display_type == "values" and len(ctitle_values)>0: + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), "")) + # must value + elif display_type == "must_value" and (function_value and len(ctitle_values)>0): + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value)) + continue + elif display_type == "must_value" and function_value: + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value)) + elif display_type == "must_value" and len(ctitle_values)>0: + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), "")) + # value + elif display_type == "value" and (function_value and len(ctitle_values)>0): + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value)) + elif display_type == "value" and function_value: + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value)) + elif display_type == "value" and len(ctitle_values)>0: + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), "")) + elif is_pre_print and display_type not in ["title", "description"]: + prompt_values.append(str_template.format(title.replace(' FORMAT', ''), function_value)) + + prompt_values = append_prompt_dfs(ctitles, prompt_values, dfs_index+1) + + return prompt_values + + start_titles = self._title_prefix + self._title_middle + self._title_suffix + return append_prompt_dfs(start_titles, prompt_values) + + def parser(self, message: Message) -> Message: + '''parse llm output into dict''' + return message + + def step_router( + self, + msg: Message, + session_index: str = "", + **kwargs + ) -> Tuple[Message, ...]: + """Route a message to the appropriate step for processing based on its action status. + + Args: + msg (Message): The input message that needs processing. + session_index (str): The session identifier for managing the conversation. + **kwargs: Additional parameters for processing. + + Returns: + Tuple[Message, ...]: The processed message and any observation message. + """ + session_index = msg.session_index or session_index or str(uuid.uuid4()) + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"message.action_status: {msg.action_status}") + + observation_msg = None + # Determine the action to take based on the message's action status + if msg.action_status == ActionStatus.CODE_EXECUTING: + msg, observation_msg = self.code_step(msg, session_index) + elif msg.action_status == ActionStatus.TOOL_USING: + msg, observation_msg = self.tool_step(msg, session_index, **kwargs) + elif msg.action_status == ActionStatus.CODING2FILE: + self.save_code2file(msg, self.workdir_path) + # Handle other action statuses as needed (currently no operations for these) + elif msg.action_status == ActionStatus.CODE_RETRIEVAL: + pass + elif msg.action_status == ActionStatus.CODING: + pass + + return msg, observation_msg + + def code_step(self, msg: Message, session_index: str) -> Message: + """Execute code contained in the message. + + Args: + msg (Message): The message containing code to be executed. 
+ session_index (str): The session identifier for managing the conversation. + + Returns: + Tuple[Message, Message]: The processed message and an observation message regarding code execution. + """ + # Execute the code using the codebox and capture the result + code_key = "code_content" + code_content = msg.spec_parsed_content.get(code_key, "") + code_answer = self.codebox.chat( + '```python\n{}```'.format(code_content) + ) + + # Prepare a response message based on code execution result + observation_title = { + "error": "The return error after executing the above code is {code_answer},need to recover.\n", + "accurate": "The return information after executing the above code is {code_answer}.\n", + "figure": "The return figure name is {uid} after executing the above code.\n" + } + code_prompt = ( + observation_title["error"].format(code_answer=code_answer.code_exe_response) + if code_answer.code_exe_type == "error" else + observation_title["accurate"].format(code_answer=code_answer.code_exe_response) + ) + + # Create an observation message for logging code execution outcome + observation_msg = Message( + session_index=session_index, + role_name="function", + role_type="observation", + input_text=code_content, + ) + + uid = str(uuid.uuid1()) # Generate a unique identifier for related content + if code_answer.code_exe_type == "image/png": + # If the code execution produces an image, log the result and update the message + msg.global_kwargs[uid] = code_answer.code_exe_response + msg.step_content += "\n**Observation:**: " + observation_title["figure"].format(uid=uid) + msg.parsed_contents.append({"Observation": observation_title["figure"].format(uid=uid)}) + observation_msg.update_content("\n**Observation:**: " + observation_title["figure"].format(uid=uid)) + observation_msg.update_parsed_content({"Observation": observation_title["figure"].format(uid=uid)}) + else: + # Log the standard execution result + msg.step_content += f"\n**Observation:**: {code_prompt}\n" + observation_msg.update_content(code_prompt) + observation_msg.update_parsed_content({"Observation": f"{code_prompt}\n"}) + + # Log the observations at the defined verbosity level + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"Code Observation: {msg.action_status}, {observation_msg.content}") + + return msg, observation_msg + + def tool_step( + self, + msg: Message, + session_index: str, + **kwargs + ) -> Message: + """Execute a tool based on parameters in the message. + + Args: + msg (Message): The message that specifies the tool to be executed. + session_index (str): The session identifier for managing the conversation. + **kwargs: Additional parameters for processing, including available tools. + + Returns: + Tuple[Message, ...]: + The processed message and an observation message regarding the tool execution. 
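+
+        Example:
+            A sketch of the message fields this step consumes (the tool name
+            and parameters are illustrative):
+
+            .. code-block:: python
+
+                msg.spec_parsed_content = {
+                    "tool_name": "some_tool",
+                    "tool_param": {"city": "shanghai"},
+                }
+                msg, obs_msg = self.tool_step(
+                    msg, session_index, tools=["some_tool"],
+                )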
+ """ + observation_title = { + "error": "there is no tool can execute.\n", + "accurate": "", + "figure": "The return figure name is {uid} after executing the above code.\n" + } + no_tool_msg = "\n**Observation:** there is no tool can execute.\n" # Message for missing tool + tool_names = kwargs.get("tools") # Retrieve available tool names + extra_params = kwargs.get("extra_params", {}) + tool_param = msg.spec_parsed_content.get("tool_param", {}) # Parameters for the tool execution + tool_param.update(extra_params) + tool_name = msg.spec_parsed_content.get("tool_name", "") # Name of the tool to execute + + # Create a message to log the tool execution result + observation_msg = Message( + session_index=session_index, + role_name="function", + role_type="observation", + input_text=str(tool_param), + ) + if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose): + logger.debug(f"message: {msg.action_status}, {tool_param}") + + if tool_name not in tool_names: + msg.step_content += f"\n{no_tool_msg}" + observation_msg.update_content(no_tool_msg) + observation_msg.update_parsed_content({"Observation": no_tool_msg}) + else: + # Execute the specified tool and capture the result + tool = get_tool(tool_name) + tool_res = tool.run(**tool_param) + msg.step_content += f"\n**Observation:** {tool_res}.\n" + msg.parsed_contents.append({"Observation": f"{tool_res}.\n"}) + observation_msg.update_content(f"**Observation:** {tool_res}.\n") + observation_msg.update_parsed_content({"Observation": f"{tool_res}.\n"}) + + # Log the observations at the defined verbosity level + if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose): + logger.info(f"**Observation:** {msg.action_status}, {observation_msg.content}") + + return msg, observation_msg + + def save_code2file(self, msg: Message, project_dir="./"): + """Save the code from the message to a specified file. + + Args: + msg (Message): The message containing the code to be saved. + project_dir (str): Directory path where the code file will be saved. 
+ """ + filename = msg.parsed_content.get("SaveFileName") # Retrieve filename from message content + code = msg.spec_parsed_content.get("code") # Extract code content from the message + + # Replace HTML entities in the code + for k, v in {">": ">", "≥": ">=", "<": "<", "≤": "<="}.items(): + code = code.replace(k, v) + + project_dir_path = os.path.join(self.workdir_path, project_dir) # Construct project directory path + file_path = os.path.join(project_dir_path, filename) # Full path for the output code file + + # Create directories if they don't exist + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + # Write the code to the file + with open(file_path, "w") as f: + f.write(code) + + def handler_prompt_values(self, **kwargs) -> Mapping[str, str]: + """Handling prompt values from memory, message' global or content + or step content or spec parsed content + """ + raise NotImplementedError( + f"Prompt Manager Wrapper [{type(self).__name__}]" + f" is missing the required `handler_prompt_values`" + f" method.", + ) + + def handle_empty_key(self, **kwargs) -> str: + '''return "" ''' + return "" + + def handler_input_key(self, **kwargs) -> str: + '''return {input_template}''' + return self.input_template + + def handler_output_key(self, **kwargs) -> str: + '''return {output_template}''' + return self.output_template + \ No newline at end of file diff --git a/muagent/prompt_manager/common_prompt_manager.py b/muagent/prompt_manager/common_prompt_manager.py new file mode 100644 index 0000000..feb628f --- /dev/null +++ b/muagent/prompt_manager/common_prompt_manager.py @@ -0,0 +1,320 @@ +from typing import ( + List, + Any, + Union, + Optional, + Literal +) +import copy +from pydantic import BaseModel +import random +from textwrap import dedent +from loguru import logger +import json + +from .base import * +from .base_prompt_manager import BasePromptManager +from ..schemas import Memory, Message, PromptConfig +from ..tools import get_tool, BaseToolModel + +from muagent.connector.utils import * + + +class CommonPromptManager(BasePromptManager): + """Prompt Manager of MarkDown style""" + + pm_type: str = "CommonPromptManager" + """The type of prompt manager.""" + + def __init__( + self, + system_prompt: str = "you are a helpful assistant!\n", + input_template: Union[str, BaseModel] = "", + output_template: Union[str, BaseModel] = "", + prompt: Optional[str] = "", + language: Literal["en", "zh"] = "en", + *, + extra_registry_titles: Dict = {}, + extra_register_edges: List = [], + new_dfsindex_to_str_format: Dict = {}, + monitored_agents=[], + monitored_fields=[], + **kwargs + ): + super().__init__( + system_prompt=system_prompt, + input_template=input_template, + output_template=output_template, + prompt=prompt, + language=language, + monitored_agents=monitored_agents, + monitored_fields=monitored_fields, + **kwargs + ) + + # update new titles + self.extra_registry_titles: Dict = extra_registry_titles + self.extra_register_edges: List = extra_register_edges + self.new_dfsindex_to_str_format: Dict = new_dfsindex_to_str_format + + # + self.register_graph( + TITLE_CONFIGS_LANGUAGE[self.language], + TITLE_EDGES_LANGUAGE[self.language], + TITLE_FORMAT_LANGUAGE[self.language], + titles=TITLE_LANGUAGE[self.language], + zero_titles=ZERO_TITLES_LANGUAGE[self.language], + common_texts=COMMON_TEXT_LANGUAGE[self.language], + ) + + def register_prompt(self, ): + """register input/output/prompt into titles and edges""" + input_str, output_str = "", "" + input_values, 
output_values = {}, {} + + if self.system_prompt: + input_str = extract_section( + self.system_prompt, + self._zero_titles["input"] + ) + output_str = extract_section( + self.system_prompt, + self._zero_titles["output"] + ) + + input_values = parse_section_to_dict( + self.system_prompt, + self._zero_titles["input"] + ) + output_values = parse_section_to_dict( + self.system_prompt, + self._zero_titles["output"] + ) + self.system_prompt = extract_section( + self.system_prompt, + self._zero_titles["agent"] + ) or self.system_prompt + + if self.input_template: + input_values = parse_section_to_dict( + self.input_template, + self._zero_titles["input"] + ) or input_values + + self.input_template = extract_section( + self.input_template, + self._zero_titles["input"] + ) or input_str + + if self.output_template: + output_values = parse_section_to_dict( + self.output_template, + self._zero_titles["output"] + ) or output_values + self.output_template = extract_section( + self.output_template, + self._zero_titles["output"] + ) or output_str + # + self._registry_titles[self._zero_titles["input"]].update({ + "description": self.input_template or input_str, + }) + + self._registry_titles[self._zero_titles["output"]].update({ + "description": self.output_template or output_str, + }) + self._registry_titles.update( + {k: { + "description": v, + "function": "handle_custom_data", + "display_type": "value", + "str_template": "**{}:** {}", + } + for k,v in (input_values|output_values).items()} + ) + self._registry_edges.extend( + [(self._zero_titles["output"], k) for k in input_values.keys()] + ) + self._registry_edges.extend( + [(self._zero_titles["output"], k) for k in output_values.keys()] + ) + + def pre_print(self, **kwargs): + kwargs.update({"is_pre_print": True}) + prompt = self.generate_prompt(**kwargs) + + input_keys = parse_section(self.system_prompt, self._zero_titles["output"]) + llm_predict = "\n".join([f"**{k}:**" for k in input_keys]) + return_prompt = ( + f"{prompt}\n\n" + f"{'#'*19}" + "\n<<<>>>\n" + f"{'#'*19}" + f"\n\n{llm_predict}\n" + ) + return return_prompt + + def parser(self, message: Message) -> Message: + '''parse llm output into dict''' + content = message.content + # parse start + parsed_dict = parse_text_to_dict(content) + spec_parsed_dict = parse_dict_to_dict(parsed_dict) + # select parse value + action_value = parsed_dict.get('Action Status') + if action_value: + action_value = action_value.lower() + + code_content_value = spec_parsed_dict.get('python') or \ + spec_parsed_dict.get('java') + if action_value == 'tool_using': + tool_params_value = spec_parsed_dict.get('json') + else: + tool_params_value = {} + + # add parse value to message + message.action_status = action_value or "default" + spec_parsed_dict["code_content"] = code_content_value + spec_parsed_dict["tool_param"] = tool_params_value.get("tool_params") + spec_parsed_dict["tool_name"] = tool_params_value.get("tool_name") + # + message.update_parsed_content(parsed_dict) + message.update_spec_parsed_content(spec_parsed_dict) + return message + + def handler_prompt_values(self, **kwargs) -> Dict[str, str]: + memory: Memory = kwargs.get("memory", None) + query: Message = kwargs.get("query", None) + result = { + "query": query.content or query.input_text if query else "", + "memory": memory.to_format_messages(format_type="str") + } + return result + + def handle_custom_data(self, **kwargs): + '''get key-value from parsed_output_list or global_kargs''' + key: str = kwargs.get("title_key", "") + query: Message = 
kwargs.get('query')
+
+        keys = [
+            "_".join([i.title() for i in key.split(" ")]),
+            " ".join([i.title() for i in key.split("_")]),
+            key
+        ]
+        keys = list(set(keys))
+
+        content = ""
+        for key in keys:
+            if key in query.spec_parsed_content:
+                content = query.spec_parsed_content.get(key)
+                content = "\n".join(content) if isinstance(content, list) else content
+                break
+            if key in query.global_kwargs:
+                content = query.global_kwargs.get(key)
+                content = "\n".join(content) if isinstance(content, list) else content
+                break
+
+        return content
+
+    def handle_tool_data(self, **kwargs):
+        """Render the tool section from the given tool names."""
+        if 'tools' not in kwargs: return ""
+
+        tools: List = kwargs.get('tools')
+        prompt: str = kwargs.get('prompt')
+        tools: List[BaseToolModel] = [get_tool(tool) for tool in tools if isinstance(tool, str)]
+
+        if len(tools) == 0: return ""
+
+        tool_strings = []
+        for tool in tools:
+            args_str = f'args: {str(tool.intput_to_json_schema())}' if tool.ToolInputArgs else ""
+            tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
+        formatted_tools = "\n".join(tool_strings)
+
+        tool_names = ", ".join([tool.name for tool in tools])
+
+        tool_prompt = dedent(prompt.format(formatted_tools=formatted_tools, tool_names=tool_names))
+        while "\n " in tool_prompt:
+            tool_prompt = tool_prompt.replace("\n ", "\n")
+
+        return tool_prompt
+
+    def handle_agent_data(self, **kwargs):
+        """Render the agent-selection section from agent_names/agent_descs."""
+        if 'agent_names' not in kwargs or "agent_descs" not in kwargs:
+            return ""
+
+        agent_names: List = kwargs.get('agent_names')
+        agent_descs: List = kwargs.get('agent_descs')
+        prompt: str = kwargs.get('prompt')
+
+        if len(agent_names) == 0: return ""
+
+        # shuffle to avoid position bias while keeping each name aligned
+        # with its description
+        paired = list(zip(agent_names, agent_descs))
+        random.shuffle(paired)
+        agent_names = [name for name, _ in paired]
+        agent_descriptions = []
+        for agent_name, desc in paired:
+            while "\n\n" in desc:
+                desc = desc.replace("\n\n", "\n")
+            desc = desc.replace("\n", ",")
+            agent_descriptions.append(
+                f'"role name: {agent_name}\nrole description: {desc}"'
+            )
+
+        agent_description = "\n".join(agent_descriptions)
+        agent_prompt = dedent(
+            prompt.format(agents=agent_description, agent_names=agent_names)
+        )
+
+        while "\n " in agent_prompt:
+            agent_prompt = agent_prompt.replace("\n ", "\n")
+
+        return agent_prompt
+
+    def handle_current_query(self, **kwargs) -> str:
+        """Return the current query's input text."""
+        query: Message = kwargs.get('query')
+        if query:
+            return query.input_text
+        return ""
+
+    def handle_session_records(self, **kwargs) -> str:
+        memory: Memory = kwargs.get('memory', Memory(messages=[]))
+        return memory.to_format_messages(
+            content_key='parsed_contents',
+            format_type='str',
+            with_tag=True
+        )
+
+    def handle_agent_profile(self, **kwargs) -> str:
+        # use the language-specific agent title instead of a hardcoded
+        # English section name
+        return extract_section(self.system_prompt, self._zero_titles["agent"]) or self.system_prompt
+
+    def handle_output_format(self, **kwargs) -> str:
+        return extract_section(self.system_prompt, self._zero_titles["output"])
+
+    def handle_react_memory(self, **kwargs) -> str:
+        react_memory: Memory = kwargs.get('react_memory')
+        if react_memory:
+            return react_memory.to_format_messages(format_type="str")
+        return ""
+
+    def handle_task_memory(self, **kwargs) -> str:
+        if 'task_memory' not in kwargs:
+            return ""
+
+        task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
+        if task_memory is None:
+            return ""
+
+        return "\n".join([
+            "\n".join([f"**{k}:**\n{v}" for k, v in _dict.items() if k not in ["CURRENT_STEP"]])
+            for _dict in task_memory.get_memory_values("parsed_content")
+        ])
+
+    def handle_current_plan(self, **kwargs) -> str:
+        if 'query' not in kwargs:
+            return ""
+        query: Message = kwargs['query']
+        return 
query.global_kwargs.get("CURRENT_STEP", "") + \ No newline at end of file diff --git a/muagent/prompt_manager/language/en.py b/muagent/prompt_manager/language/en.py new file mode 100644 index 0000000..7fca273 --- /dev/null +++ b/muagent/prompt_manager/language/en.py @@ -0,0 +1,89 @@ +EN_TITLE_EDGES = [ + ("AGENT PROFILE", "ROLE"), + ("AGENT PROFILE", "AGENT INFORMATION"), + ("AGENT PROFILE", "TOOL INFORMATION"), + ("CONTEXT FORMAT", "SESSION RECORDS"), + ("CONTEXT FORMAT", "CURRENT QUERY"), +] + +EN_TITLE_CONFIGS = { + "AGENT PROFILE": { + "description": "", + "function": "handle_empty_key", + "display_type": "title" + }, + "CONTEXT FORMAT": { + "description": "Use the content provided in the context.", + "function": "handle_empty_key", + "display_type": "values" + }, + "INPUT FORMAT": { + "description": "", + "function": "handle_empty_key", + "display_type": "values" + }, + "RESPONSE OUTPUT FORMAT": { + "description": "", + "function": "handle_react_memory", + "display_type": "must_value" + }, + "ROLE": { + "description": "", + "prompt": "", + "function": "handle_agent_profile", + "display_type": "description" + }, + "TOOL INFORMATION": { + "description": "", + "prompt": """Below is a list of tools that are available for your use:{formatted_tools}\nvalid "tool_name" value is:\n{tool_names}""", + "function": "handle_tool_data", + "display_type": "description" + }, + "AGENT INFORMATION": { + "description": "", + "prompt": '''Please ensure your selection is one of the listed roles. Available roles for selection:\n{agents}Please ensure select the Role from agent names, such as {agent_names}''', + "function": "handle_agent_data", + "display_type": "description" + }, + "SESSION RECORDS": { + "description": "In this part, we will supply with the context about this question.", + "function": "handle_session_records", + "display_type": "value" + }, + "CURRENT QUERY": { + "description": "In this part, we will supply with current question to do.", + "function": "handle_current_query", + "display_type": "value" + }, +} + + + + +EN_TITLE_FORMAT = { + 0: "#### {}\n{}", + 1: "### {}\n{}", + 2: "## {}\n{}", + 3: "# {}\n{}", +} + + +EN_ZERO_TITLES = { + "agent": "AGENT PROFILE", + "context": "CONTEXT FORMAT", + "input": "INPUT FORMAT", + "output": "RESPONSE OUTPUT FORMAT" +} + + +EN_TITLES = { + "title_prefix": [EN_ZERO_TITLES["agent"], EN_ZERO_TITLES["context"]], + "title_suffix": [EN_ZERO_TITLES["input"], EN_ZERO_TITLES["output"]], + "title_middle": [], +} + + +EN_COMMON_TEXT = { + "transition_text": "BEGIN!!!", + "reponse_text": "Please response:" +} diff --git a/muagent/prompt_manager/language/zh.py b/muagent/prompt_manager/language/zh.py new file mode 100644 index 0000000..cfaf427 --- /dev/null +++ b/muagent/prompt_manager/language/zh.py @@ -0,0 +1,87 @@ +ZH_TITLE_EDGES = [ + ("智能体配置", "角色"), + ("智能体配置", "智能体信息"), + ("智能体配置", "工具信息"), + ("上下文", "会话记录"), + ("上下文", "当前问题"), +] + +ZH_TITLE_CONFIGS = { + "智能体配置": { + "description": "", + "function": "handle_empty_key", + "display_type": "title" + }, + "上下文": { + "description": "使用下面内容作为上下文的信息。", + "function": "handle_empty_key", + "display_type": "values" + }, + "输入": { + "description": "", + "function": "handle_empty_key", + "display_type": "values" + }, + "输出": { + "description": "", + "function": "handle_react_memory", + "display_type": "must_value" + }, + "角色": { + "description": "", + "prompt": "", + "function": "handle_agent_profile", + "display_type": "description" + }, + "工具信息": { + "description": "", + "prompt": 
"""以下是您可以使用的工具列表:{formatted_tools}\n有效的 "tool_name" 值是:\n{tool_names}""", + "function": "handle_tool_data", + "display_type": "description" + }, + "智能体信息": { + "description": "", + "prompt": '''请确保您的选择是列出的角色之一。可供选择的角色有:\n{agents}请确保从代理名称中选择角色,例如 {agent_names}''', + "function": "handle_agent_data", + "display_type": "description" + }, + "会话记录": { + "description": "在这个部分,我们将提供有关这个问题的上下文。", + "function": "handle_session_records", + "display_type": "value" + }, + "当前问题": { + "description": "在这个部分,我们将提供当前需要处理的问题。", + "function": "handle_current_query", + "display_type": "value" + }, +} + + + + +ZH_TITLE_FORMAT = { + 0: "#### {}\n{}", + 1: "### {}\n{}", + 2: "## {}\n{}", + 3: "# {}\n{}", +} + +ZH_ZERO_TITLES = { + "agent": "智能体配置", + "context": "上下文", + "input": "输入", + "output": "输出" +} + +ZH_TITLES = { + "title_prefix": [ZH_ZERO_TITLES["agent"], ZH_ZERO_TITLES["context"]], + "title_suffix": [ZH_ZERO_TITLES["input"], ZH_ZERO_TITLES["output"]], + "title_middle": [], +} + + +ZH_COMMON_TEXT = { + "transition_text": "开始", + "reponse_text": "请回答:" +} diff --git a/muagent/prompt_manager/util.py b/muagent/prompt_manager/util.py new file mode 100644 index 0000000..0885b1d --- /dev/null +++ b/muagent/prompt_manager/util.py @@ -0,0 +1,94 @@ +from collections import defaultdict + + + +class GraphCycleError(Exception): + """Custom exception for graph cycle detection.""" + pass + + + +def edges_to_graph_with_cycle_detection(intervals): + """Converts a list of intervals into a directed graph and checks for cycles. + + Args: + intervals (list of tuple): List of intervals where each interval is defined by (start, end). + + Returns: + tuple: A tuple containing a list of start nodes (nodes with indegree of 0) and the constructed graph. + + Raises: + GraphCycleError: If the graph contains a cycle. + """ + + graph = defaultdict(list) # Adjacency list for the graph + indegree = defaultdict(int) # Count of incoming edges for each node + + # Build the graph and the indegree table + for start, end in intervals: + graph[start].append(end) # Add directed edge from start to end + indegree[end] += 1 # Increment indegree of end node + # Ensure every node is in the graph (even nodes without outgoing edges) + if start not in indegree: + indegree[start] = 0 # Initialize indegree for start node + + # Find all starting nodes (indegree of 0) + start_nodes = [node for node in indegree if indegree[node] == 0] + + # Detect cycle in the graph + if detect_cycle(graph): + raise GraphCycleError("Graph contains a cycle!") # Raise error if cycle is found + + return start_nodes, graph + + + +def detect_cycle(graph): + """Detects if a directed graph contains a cycle using DFS. + + Args: + graph (dict): The adjacency list of the graph. + + Returns: + bool: True if a cycle is detected, False otherwise. + """ + + visited = set() # To keep track of visited nodes + rec_stack = set() # To keep track of nodes currently in the recursion stack + + def dfs(node): + """Performs a DFS on the graph to detect cycles. + + Args: + node: Current node being visited. + + Returns: + bool: True if a cycle is detected. 
+ """ + # If node is in recursion stack, a cycle is found + if node in rec_stack: + return True + # If node is already visited, no need to check it again + if node in visited: + return False + + # Mark the current node as visited and add to recursion stack + visited.add(node) + rec_stack.add(node) + + # Use list() to copy neighbors to avoid modifying while iterating + for neighbor in list(graph[node]): + if dfs(neighbor): # Recursive call for each neighbor + return True # Cycle detected in the neighbor + + # Remove the node from the recursion stack after visiting + rec_stack.remove(node) + return False # No cycle detected in this path + + # Iterate over each node in the graph to detect cycles + for node in list(graph.keys()): + if node not in visited: # Proceed if the node hasn't been visited yet + if dfs(node): # Start DFS + return True # Cycle found + + return False # No cycles found in the graph \ No newline at end of file diff --git a/muagent/sandbox/__init__.py b/muagent/sandbox/__init__.py index 435da9b..ac51b6b 100644 --- a/muagent/sandbox/__init__.py +++ b/muagent/sandbox/__init__.py @@ -1,6 +1,7 @@ from .basebox import CodeBoxResponse from .pycodebox import PyCodeBox +from .nbclient import NBClientBox, NoteBookExecutor __all__ = [ - "CodeBoxResponse", "PyCodeBox" + "CodeBoxResponse", "PyCodeBox", "NBClientBox" ] \ No newline at end of file diff --git a/muagent/sandbox/nbclient.py b/muagent/sandbox/nbclient.py new file mode 100644 index 0000000..c52ce01 --- /dev/null +++ b/muagent/sandbox/nbclient.py @@ -0,0 +1,297 @@ +"""Service for executing jupyter notebooks interactively +Partially referenced the implementation of +https://github.com/modelscope/agentscope/blob/main/src/agentscope/service/execute_code/exec_notebook.py +""" +import base64 +import asyncio +from loguru import logger + +try: + import nbclient + import nbformat +except ImportError: + nbclient = None + nbformat = None + + +import os, asyncio, re +from typing import List, Optional +from loguru import logger +from ..base_configs.env_config import KB_ROOT_PATH +from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus + + + +class NoteBookExecutor: + """ + Class for executing jupyter notebooks block interactively. + To use the service function, you should first init the class, then call the + run_code_on_notebook function. + + Example: + + ```ipython + from agentscope.service.service_toolkit import * + from agentscope.service.execute_code.exec_notebook import * + nbe = NoteBookExecutor() + code = "print('helloworld')" + # calling directly + nbe.run_code_on_notebook(code) + + >>> Executing function run_code_on_notebook with arguments: + >>> code: print('helloworld') + >>> END + + # calling with service toolkit + service_toolkit = ServiceToolkit() + service_toolkit.add(nbe.run_code_on_notebook) + input_obs = [{"name": "run_code_on_notebook", "arguments":{"code": code}}] + res_of_string_input = service_toolkit.parse_and_call_func(input_obs) + + "1. Execute function run_code_on_notebook\n [ARGUMENTS]:\n code: print('helloworld')\n [STATUS]: SUCCESS\n [RESULT]: ['helloworld\\n']\n" + + ``` + """ # noqa + + def __init__( + self, + timeout: int = 300, + work_path: str = KB_ROOT_PATH, + ) -> None: + """ + The construct function of the NoteBookExecutor. + Args: + timeout (Optional`int`): + The timeout for each cell execution. + Default to 300. + """ + + if nbclient is None or nbformat is None: + raise ImportError( + "The package nbclient or nbformat is not found. 
Please " + "install it by `pip install notebook nbclient nbformat`", + ) + + self.nb = nbformat.v4.new_notebook() + self.nb_client = nbclient.NotebookClient(nb=self.nb) + self.work_path = work_path + self.ori_path = os.getcwd() + self.timeout = timeout + + asyncio.run(self._start_client()) + + def _output_parser(self, output: dict) -> str: + """Parse the output of the notebook cell and return str""" + if output["output_type"] == "stream": + return output["text"] + elif output["output_type"] == "execute_result": + return output["data"]["text/plain"] + elif output["output_type"] == "display_data": + if "image/png" in output["data"]: + file_path = self._save_image(output["data"]["image/png"]) + return f"Displayed image saved to {file_path}" + else: + return "Unsupported display type" + elif output["output_type"] == "error": + return output["traceback"] + else: + logger.info(f"Unsupported output encountered: {output}") + return "Unsupported output encountered" + + async def _start_client(self) -> None: + """start notebook client""" + if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): + os.chdir(self.work_path) + self.nb_client.create_kernel_manager() + self.nb_client.start_new_kernel() + self.nb_client.start_new_kernel_client() + os.chdir(self.ori_path) + + async def _kill_client(self) -> None: + """kill notebook client""" + if ( + self.nb_client.km is not None + and await self.nb_client.km.is_alive() + ): + await self.nb_client.km.shutdown_kernel(now=True) + await self.nb_client.km.cleanup_resources() + + self.nb_client.kc.stop_channels() + self.nb_client.kc = None + self.nb_client.km = None + + async def _restart_client(self) -> None: + """Restart the notebook client""" + await self._kill_client() + self.nb_client = nbclient.NotebookClient(self.nb, timeout=self.timeout) + await self._start_client() + + async def _run_cell(self, cell_index: int): + """Run a cell in the notebook by its index""" + try: + self.nb_client.execute_cell(self.nb.cells[cell_index], cell_index) + return self.nb.cells[cell_index].outputs + return [self._output_parser(output) for output in self.nb.cells[cell_index].outputs] + except nbclient.exceptions.DeadKernelError: + await self.reset_notebook() + return "DeadKernelError when executing cell, reset kernel" + except nbclient.exceptions.CellTimeoutError: + assert self.nb_client.km is not None + await self.nb_client.km.interrupt_kernel() + return ( + "CellTimeoutError when executing cell" + ", code execution timeout" + ) + except Exception as e: + return str(e) + + @property + def cells_length(self) -> int: + """return cell length""" + return len(self.nb.cells) + + async def async_run_code_on_notebook(self, code: str): + """ + Run the code on interactive notebook + """ + self.nb.cells.append(nbformat.v4.new_code_cell(code)) + cell_index = self.cells_length - 1 + return await self._run_cell(cell_index) + + def run_code_on_notebook(self, code: str): + """ + Run the code on interactive jupyter notebook. + + Args: + code (`str`): + The Python code to be executed in the interactive notebook. + + Returns: + `ServiceResponse`: whether the code execution was successful, + and the output of the code execution. 
+ """ + return asyncio.run(self.async_run_code_on_notebook(code)) + + def reset_notebook(self) -> str: + """ + Reset the notebook + """ + asyncio.run(self._restart_client()) + return "Reset notebook" + + + + + +class NBClientBox(BaseBox): + + enter_status: bool = False + + def __init__( + self, + do_code_exe: bool = False, + work_path: str = KB_ROOT_PATH, + ): + self.nbe = NoteBookExecutor(work_path=work_path) + self.do_code_exe = do_code_exe + + def decode_code_from_text(self, text: str) -> str: + pattern = r'```.*?```' + code_blocks = re.findall(pattern, text, re.DOTALL) + code_text: str = "\n".join([block.strip('`') for block in code_blocks]) + code_text = code_text[6:] if code_text.startswith("python") else code_text + code_text = code_text.replace("python\n", "").replace("code", "") + return code_text + + def run( + self, code_text: Optional[str] = None, + file_path: Optional[os.PathLike] = None, + retry = 3, + ) -> CodeBoxResponse: + if not code_text and not file_path: + return CodeBoxResponse( + code_exe_response="Code or file_path must be specifieds!", + code_text=code_text, + code_exe_type="text", + code_exe_status=502, + do_code_exe=self.do_code_exe, + ) + + if code_text and file_path: + return CodeBoxResponse( + code_exe_response="Can only specify code or the file to read_from!", + code_text=code_text, + code_exe_type="text", + code_exe_status=502, + do_code_exe=self.do_code_exe, + ) + + if file_path: + with open(file_path, "r", encoding="utf-8") as f: + code_text = f.read() + + + def _output_parser(output: dict) -> str: + """Parse the output of the notebook cell and return str""" + if output["output_type"] == "stream": + return CodeBoxResponse( + code_exe_type="text", + code_text=code_text, + code_exe_response=output["text"] or "Code run successfully (no output)", + code_exe_status=200, + do_code_exe=self.do_code_exe + ) + elif output["output_type"] == "execute_result": + return CodeBoxResponse( + code_exe_type="text", + code_text=code_text, + code_exe_response=output["data"]["text/plain"] or "Code run successfully (no output)", + code_exe_status=200, + do_code_exe=self.do_code_exe + ) + elif output["output_type"] == "display_data": + if "image/png" in output["data"]: + return CodeBoxResponse( + code_exe_type="image/png", + code_text=code_text, + code_exe_response=output["data"]["image/png"], + code_exe_status=200, + do_code_exe=self.do_code_exe + ) + else: + return CodeBoxResponse( + code_exe_type="error", + code_text=code_text, + code_exe_response="Unsupported display type", + code_exe_status=420, + do_code_exe=self.do_code_exe + ) + elif output["output_type"] == "error": + return CodeBoxResponse( + code_exe_type="error", + code_text=code_text, + code_exe_response="error", + code_exe_status=500, + do_code_exe=self.do_code_exe + ) + else: + return CodeBoxResponse( + code_exe_type="error", + code_text=code_text, + code_exe_response=f"Unsupported output encountered: {output}", + code_exe_status=420, + do_code_exe=self.do_code_exe + ) + + contents = self.nbe.run_code_on_notebook(code_text) + content = contents[0] + return _output_parser(content) + + def restart(self, ) -> CodeBoxStatus: + return CodeBoxStatus(status="restared") + + def stop(self, ) -> CodeBoxStatus: + pass + + def __del__(self): + self.stop() \ No newline at end of file diff --git a/muagent/schemas/__init__.py b/muagent/schemas/__init__.py index e69de29..a43edef 100644 --- a/muagent/schemas/__init__.py +++ b/muagent/schemas/__init__.py @@ -0,0 +1,12 @@ +from .message import Message +from .memory import Memory 
+from .agent_config import PromptConfig, AgentConfig +from .project_config import ProjectConfig, EKGProjectConfig +from .models import LLMConfig, ModelConfig + + +__all__ = [ + "Message", "Memory", + "PromptConfig", "AgentConfig", "LLMConfig", "ModelConfig", + "EKGProjectConfig", "ProjectConfig", +] diff --git a/muagent/schemas/agent_config.py b/muagent/schemas/agent_config.py new file mode 100644 index 0000000..fe1763d --- /dev/null +++ b/muagent/schemas/agent_config.py @@ -0,0 +1,68 @@ + +from pydantic import BaseModel, root_validator +from typing import List, Dict, Optional, Union, Literal + + +class PromptConfig(BaseModel): + """The dataclass for prompt config.""" + + config_name: str = "codefuse" + """The config name of prompt.""" + + prompt_manager_type: str = "CommonPromptManager" + """The type of prompt manager.""" + + language: Literal['en', 'zh'] = 'en' + """The language of prompt manager.""" + + +class AgentConfig(BaseModel): + """The dataclass for agent config""" + + config_name: str + """The name of the agent configuration. It equals to agent name""" + + agent_type: str + """The type of the agent wrapper, which is to identify the agent wrapper + class in model configuration.""" + + agent_name: str + """The name of the agent, which is used in agent api calling. It will eqaul to role name""" + + agent_desc: str = "" + """The role description of this role.""" + + system_prompt: str = "" + """The system prompt of this role.""" + + input_template: Union[str, BaseModel] = "" + """The input template for role.""" + + output_template: Union[str, BaseModel] = "" + """The output template for role.""" + + prompt: str = "" + """The full prompt of this role. it will override system prompt + input prompt + output prompt""" + + tools: List[str] = [] + """The tools' name of this role. it will use these tools to complete task""" + + agents: List[str] = [] + """This role can manage some agents. 
It will ask one agent to complete task""" + + # + llm_config_name: Optional[str] + """The name of the llm model configuration.""" + + em_config_name: Optional[str] + """The name of the embedding model configuration.""" + + prompt_config_name: Optional[str] + """""" + + @root_validator(pre=True) + def set_default_config_name(cls, values): + """Set config_name to model_name if config_name is not provided.""" + if 'config_name' not in values or values['config_name'] is None: + values['config_name'] = values.get('agent_name') + return values \ No newline at end of file diff --git a/muagent/schemas/apis/ekg_api_schema.py b/muagent/schemas/apis/ekg_api_schema.py index 5c677df..f4d6eef 100644 --- a/muagent/schemas/apis/ekg_api_schema.py +++ b/muagent/schemas/apis/ekg_api_schema.py @@ -1,8 +1,9 @@ from pydantic import BaseModel -from typing import List, Dict, Optional, Literal +from typing import List, Dict, Optional, Literal, Union from enum import Enum from muagent.schemas.common import GNode, GEdge +from muagent.schemas.models import ChatMessage, Choice @@ -41,6 +42,20 @@ class LLMRequest(BaseModel): text: str stop: Optional[str] + +class LLMFCRequest(BaseModel): + messages: List[ChatMessage] + system_prompt: Optional[str] = None + tools: List[Union[str, object]] = [] + tool_choice: Optional[Literal["auto", "required"]] = "auto" + parallel_tool_calls: bool = False + stop: Optional[str] + + +class LLMFCResponse(EKGResponse): + choices: List[Choice] + + class LLMResponse(EKGResponse): successCode: int errorMessage: str @@ -123,7 +138,7 @@ class SearchAncestorRequest(BaseModel): class LLMParamsResponse(BaseModel): url: Optional[str] = None model_name: str - model_type: Literal["openai", "ollama", "lingyiwanwu", "kimi", "moonshot", "qwen"] = "ollama" + model_type: str = "ollama" api_key: str = "" stop: Optional[str] = None temperature: float = 0.3 @@ -137,7 +152,7 @@ class LLMParamsRequest(LLMParamsResponse): class EmbeddingsParamsResponse(BaseModel): # ollama embeddings url: Optional[str] = None - embedding_type: Literal["openai", "ollama"] = "ollama" + embedding_type: str = "ollama" model_name: str = "qwen2.5:0.5b" api_key: str = "" diff --git a/muagent/schemas/common/__init__.py b/muagent/schemas/common/__init__.py index d47cbe9..d633342 100644 --- a/muagent/schemas/common/__init__.py +++ b/muagent/schemas/common/__init__.py @@ -1,8 +1,12 @@ from .auto_extract_graph_schema import * - +from .actions import * +from .log import LogVerboseEnum __all__ = [ "GNodeAbs", "GEdgeAbs", "GRelationAbs", "Attribute", "GNode", "GEdge", "Graph", "GEdgeRequst", "GNodeRequest", "GRelation", - "ThemeEnums", "GbaseExecStatus" + "ThemeEnums", "GbaseExecStatus", + + "ActionStatus", + "LogVerboseEnum", ] \ No newline at end of file diff --git a/muagent/schemas/common/actions.py b/muagent/schemas/common/actions.py new file mode 100644 index 0000000..60d4bb6 --- /dev/null +++ b/muagent/schemas/common/actions.py @@ -0,0 +1,76 @@ +from pydantic import BaseModel +from enum import Enum + + + +class ActionStatus(Enum): + DEFAUILT = "default" + + FINISHED = "finished" + STOPPED = "stopped" + CONTINUED = "continued" + + TOOL_USING = "tool_using" + CODING = "coding" + CODE_EXECUTING = "code_executing" + CODING2FILE = "coding2file" + + PLANNING = "planning" + UNCHANGED = "unchanged" + ADJUSTED = "adjusted" + CODE_RETRIEVAL = "code_retrieval" + + def __eq__(self, other): + if isinstance(other, str): + return self.value.lower() == other.lower() + return super().__eq__(other) + + +class Action(BaseModel): + action_name: str + 
description: str
+
+class FinishedAction(Action):
+    action_name: str = ActionStatus.FINISHED.value
+    description: str = "provide the final answer to the original query and end the chain"
+
+class StoppedAction(Action):
+    action_name: str = ActionStatus.STOPPED.value
+    description: str = "provide the final answer to the original query and stop the agent"
+
+class ContinuedAction(Action):
+    action_name: str = ActionStatus.CONTINUED.value
+    description: str = "can't yet provide the final answer to the original query"
+
+class ToolUsingAction(Action):
+    action_name: str = ActionStatus.TOOL_USING.value
+    description: str = "proceed with using the specified tool."
+
+class CodingdAction(Action):
+    action_name: str = ActionStatus.CODING.value
+    description: str = "provide the answer by writing code"
+
+class Coding2FileAction(Action):
+    action_name: str = ActionStatus.CODING2FILE.value
+    description: str = "provide the answer by writing code and a filename"
+
+class CodeExecutingAction(Action):
+    action_name: str = ActionStatus.CODE_EXECUTING.value
+    description: str = "provide the answer by writing executable code"
+
+class PlanningAction(Action):
+    action_name: str = ActionStatus.PLANNING.value
+    description: str = "provide a sequence of tasks"
+
+class UnchangedAction(Action):
+    action_name: str = ActionStatus.UNCHANGED.value
+    description: str = "this PLAN has no problem; just set PLAN_STEP to CURRENT_STEP+1."
+
+class AdjustedAction(Action):
+    action_name: str = ActionStatus.ADJUSTED.value
+    description: str = "provide an optimized version of the original plan."
+
+# extended action example
+class CodeRetrievalAction(Action):
+    action_name: str = ActionStatus.CODE_RETRIEVAL.value
+    description: str = "execute the code retrieval to acquire more code information"
diff --git a/muagent/schemas/common/log.py b/muagent/schemas/common/log.py
new file mode 100644
index 0000000..426815e
--- /dev/null
+++ b/muagent/schemas/common/log.py
@@ -0,0 +1,38 @@
+from enum import Enum
+from typing import Union
+
+
+class LogVerboseEnum(Enum):
+    Log0Level = "0"  # don't print logs
+    Log1Level = "1"  # print level-1 logs
+    Log2Level = "2"  # print level-2 logs
+    Log3Level = "3"  # print level-3 logs
+
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.value.lower() == other.lower()
+        if isinstance(other, LogVerboseEnum):
+            return self.value == other.value
+        return False
+
+    # defining __eq__ resets __hash__ to None, so restore it to keep
+    # enum members usable in sets and as dict keys
+    def __hash__(self):
+        return hash(self.value)
+
+    def __ge__(self, other):
+        if isinstance(other, LogVerboseEnum):
+            return int(self.value) >= int(other.value)
+        if isinstance(other, str):
+            return int(self.value) >= int(other)
+        return NotImplemented
+
+    def __le__(self, other):
+        if isinstance(other, LogVerboseEnum):
+            return int(self.value) <= int(other.value)
+        if isinstance(other, str):
+            return int(self.value) <= int(other)
+        return NotImplemented
+
+    @classmethod
+    def ge(cls, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
+        """Return True if `other` is at least as verbose as `enum_value`."""
+        return enum_value <= other
+
+    @classmethod
+    def le(cls, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
+        """Return True if `other` is at most as verbose as `enum_value`."""
+        return enum_value >= other
\ No newline at end of file
diff --git a/muagent/schemas/kb/base_schema.py b/muagent/schemas/kb/base_schema.py
index fd14756..e9dda17 100644
--- a/muagent/schemas/kb/base_schema.py
+++ b/muagent/schemas/kb/base_schema.py
@@ -1,6 +1,6 @@
 from sqlalchemy import Column, Integer, String, DateTime, func
 
-from muagent.orm.db import Base
+from muagent.db_handler.db import Base
 
 
 class KnowledgeBaseSchema(Base):
diff --git a/muagent/schemas/memory.py b/muagent/schemas/memory.py
new file mode 100644
index 0000000..7e0ba80
--- 
/dev/null
+++ b/muagent/schemas/memory.py
@@ -0,0 +1,193 @@
+from pydantic import BaseModel, PrivateAttr
+from typing import List, Union, Dict, Optional, Literal, Any
+from loguru import logger
+
+from .message import Message
+
+
+class Memory(BaseModel):
+    '''The base dataclass of Memory'''
+
+    messages: List[Message] = []
+
+    # pydantic treats leading-underscore attributes as private, so declare it
+    # with PrivateAttr to make the assignment in `set_limit` legal
+    _limit: Optional[int] = PrivateAttr(default=None)
+
+    def set_limit(self, limit: Optional[int] = None):
+        self._limit = limit
+
+    def _limit_messages(self, ):
+        if self._limit:
+            self.messages = self.messages[-self._limit:]
+
+    def append(self, message: Message):
+        self.messages.append(message)
+        self._limit_messages()
+
+    def extend(self, memory: 'Memory'):
+        self.messages.extend(memory.messages)
+        self._limit_messages()
+
+    def update(self, message: Message, role_tag: str = None):
+        '''Add `role_tag` to the stored message with the same message_index.'''
+        if role_tag is None:
+            return
+        message_index = message.message_index
+        for msg in self.messages:
+            if msg.message_index == message_index:
+                if isinstance(msg.role_tags, list):
+                    msg.role_tags = list(set(msg.role_tags + [role_tag]))
+                else:
+                    msg.role_tags += f", {role_tag}"
+                break
+
+    def sort_by_key(self, key: str):
+        self.messages = sorted(self.messages, key=lambda x: getattr(x, key, f"no such key: {key}"))
+
+    def clear(self, k: int = None):
+        '''Keep only the most recent k messages; clear all when k is None.'''
+        if k is None:
+            self.messages = []
+        else:
+            self.messages = self.messages[-k:]
+
+    def get_messages(self, k=0) -> List[Message]:
+        """Return the most recent k messages, or all of them when k=0."""
+        return self.messages[-k:]
+
+    def get_datetimes(self) -> List[Any]:
+        """Get datetime values (default: end_datetime)."""
+        return self.get_memory_values("end_datetime")
+
+    def get_contents(self) -> List[Any]:
+        """Get content values."""
+        return self.get_memory_values("content")
+
+    def get_memory_values(self, key: str) -> List[Any]:
+        return [message.get_value(key) for message in self.messages]
+
+    def split_by_role_type(self) -> List[Dict[str, 'Memory']]:
+        """
+        Split messages into rounds of conversation based on role_type.
+        Each round consists of consecutive messages of the same role_type.
+        User messages form a single round, while assistant and function messages
+        are combined into a single round. Each round is represented by a dict
+        with 'role' and 'memory' keys, with assistant and function messages
+        labeled as 'assistant'.
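+
+        Example (illustrative): for messages with role_types
+        [user, assistant, function, user], the result is
+        [{'role': 'user', ...}, {'role': 'assistant', ...}, {'role': 'user', ...}],
+        where the middle round holds both the assistant and the function message.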
+ """ + rounds = [] + current_memory = Memory() + current_role = None + + for msg in self.messages: + # Determine the message's role, considering 'function' as 'assistant' + message_role = 'assistant' if msg.role_type in ['assistant', 'function'] else 'user' + + # If the current memory is empty or the current message is of the same role_type as current_role, add to current memory + if not current_memory.messages or current_role == message_role: + current_memory.append(msg) + else: + # Finish the current memory and start a new one + rounds.append({'role': current_role, 'memory': current_memory}) + current_memory = Memory() + current_memory.append(msg) + + # Update the current_role, considering 'function' as 'assistant' + current_role = message_role + + # Don't forget to add the last memory if it exists + if current_memory.messages: + rounds.append({'role': current_role, 'memory': current_memory}) + + return rounds + + def format_rounds_to_html(self) -> str: + formatted_html_str = "" + rounds = self.split_by_role_type() + + for round in rounds: + role = round['role'] + memory = round['memory'] + + # 转换当前round的Memory为字符串 + messages_str = memory.to_str_messages() + + # 根据角色类型添加相应的HTML标签 + if role == 'user': + formatted_html_str += f"\n{messages_str}\n\n" + else: # 对于'assistant'和'function'角色,我们将其视为'assistant' + formatted_html_str += f"\n{messages_str}\n\n" + + return formatted_html_str + + def to_format_messages( + self, + attributes: dict[str, Union[any, List[any]]] = {}, + filter_type: Optional[Literal['select', 'filter']] = None, + *, + return_all: bool = True, + content_key: str = "content", + with_tag: bool = False, + format_type: Literal['raw', 'tuple', 'dict', 'str']='raw', + logic: Literal['or', 'and'] = 'and' + ) -> List[Message]: + '''Filter messages by attributes''' + def _logic_check(values: List[bool], logic): + # default: not filter any message + if values == []: return True + return any(values) if logic == "or" else all(values) + + def _select(message, attrs, select_type="filter"): + if select_type == "filter": + return [message.get(key) not in value if isinstance(value, list) else + message.get(key) != value + for key, value in attrs.items() + ] + else: + return [message.get(key) in value if isinstance(value, list) else + message.get(key) == value + for key, value in attrs.items() + ] + # + messages = [ + message for message in self.messages + if _logic_check(_select(message, attributes, filter_type), logic) + ] + + # + if format_type == "tuple": + return [ + message.to_tuple_message(return_all, content_key) + for message in messages + ] + elif format_type == "dict": + return [ + message.to_dict_message() + for message in messages + ] + elif format_type == "str": + return "\n\n".join([ + message.to_str_content(content_key, with_tag=with_tag) + for message in messages + ]) + + return messages + + @classmethod + def from_memory_list(cls, memorys: List['Memory']) -> 'Memory': + return cls(messages=[message for memory in memorys for message in memory.get_messages()]) + + def __len__(self, ): + return len(self.messages) + + def __str__(self) -> str: + return self.to_format_messages(format_type="str") + return "\n".join([": ".join(i) for i in self.to_format_messages(format_type="tuple")]) + + def __add__(self, other: Union[Message, 'Memory']) -> 'Memory': + if isinstance(other, Message): + return Memory(messages=self.messages + [other]) + elif isinstance(other, Memory): + return Memory(messages=self.messages + other.messages) + else: + raise ValueError(f"cant add unspecified type like 
as {type(other)}") + + + \ No newline at end of file diff --git a/muagent/schemas/message.py b/muagent/schemas/message.py new file mode 100644 index 0000000..a444a58 --- /dev/null +++ b/muagent/schemas/message.py @@ -0,0 +1,258 @@ +from pydantic import BaseModel, root_validator +from typing import List, Dict, Optional, Literal, Union, Sequence, Tuple +from loguru import logger +import uuid +from muagent.utils.common_utils import getCurrentDatetime + + + +class Message(BaseModel): + '''The base dataclass of Message + + The following is an example: + + .. code-block:: python + + from muagent.schemas.message import Message + msg = Message( + role_name="system", + role_type="system", + content="You're a helpful assistant", + ) + ''' + + # + role_name: str = "muagent" + '''The role name of agent to generate this message.''' + + role_type: Literal[ + 'system', + 'user', + 'assistant', + 'observation', + 'tool_call', + 'function', + 'codefuse', + 'summary' + ] = "codefuse" + '''The role type of agent to generate this message. such as system/user/assistant/observation/tool_call''' + # + role_tags: Union[Sequence[str], str] = '' + '''The tags of this message.''' + + embedding: Optional[Sequence] = None + '''The embedding from LLM of this message.''' + + image_urls: Optional[Sequence[str]] = None + '''The image_urls from LLM of this message.''' + + action_status: str = "default" + '''llm\tool\code executre information''' + + content: Optional[str] = "" + '''The last response from LLM of this message.''' + + step_content: Optional[str] = '' + '''The multi content from LLM of this message, connected by \n''' + + parsed_content: Dict = {} + '''The structed content from LLM parsing of this message''' + + parsed_contents: List[Dict] = [] + '''The multi structed content from LLM parsing of this message''' + + spec_parsed_content: Dict = {} + '''The special structed content from LLM parsing of this message''' + + spec_parsed_contents: List[Dict] = [] + '''The multi special structed content from LLM parsing of this message''' + + global_kwargs: Dict = {} + '''user's customed kargs for init or end action''' + + # input from last message + input_text: Optional[str] = "" + '''The input text from last message.''' + + parsed_input: Dict = {} + '''The structed input from LLM parsing from last message''' + + parsed_inputs: List[Dict] = [] + '''The multi structed input from LLM parsing from last message''' + + spec_parsed_input: Dict = {} + '''The special structed content from LLM parsing of this message''' + + spec_parsed_inputs: List[Dict] = [] + '''The multi special structed content from LLM parsing of this message''' + + # + session_index: Optional[str] = None + '''The session index of this message.''' + + message_index: Optional[str] = None + '''The message index of this message.''' + + node_index: Optional[str] = "default" + '''The node index of this message.''' + + # + start_datetime: str = None + '''The first record time of this message.''' + + end_datetime: str = None + '''The last update time of this message.''' + + datetime_format: str = "%Y-%m-%d %H:%M:%S.%f" + + @root_validator(pre=True) + def check_card_number_omitted(cls, values): + input_text = values.get("input_text") + content = values.get("content") + if content is None: + values["content"] = content or input_text + return values + + @root_validator(pre=True) + def check_datetime(cls, values): + start_datetime = values.get("start_datetime") + end_datetime = values.get("end_datetime") + datetime_format = values.get("datetime_format", "%Y-%m-%d 
%H:%M:%S.%f") + if start_datetime is None: + values["start_datetime"] = getCurrentDatetime(datetime_format) + if end_datetime is None: + values["end_datetime"] = getCurrentDatetime(datetime_format) + return values + + @root_validator(pre=True) + def check_message_index(cls, values): + message_index = values.get("message_index") + session_index = values.get("session_index") + if message_index is None or message_index == "": + values["message_index"] = str(uuid.uuid4()).replace("-", "_") + + if session_index is None or session_index == "": + values["session_index"] = str(uuid.uuid4()).replace("-", "_") + return values + + def update_input(self, input: Union[str, 'Message'], parsed_input: Dict = {}): + if isinstance(input, str): + self.update_attributes({"input_text": input}) + else: + self.update_attributes({"input_text": input.content}) + + def update_parsed_input(self, parsed_input: Dict): + self.update_attributes({"parsed_input": parsed_input}) + self.update_attributes({"parsed_inputs": self.parsed_inputs + [parsed_input]}) + + def update_spec_parsed_input(self, spec_parsed_input: Dict): + self.update_attributes({"spec_parsed_input": spec_parsed_input}) + self.update_attributes({"spec_parsed_inputs": self.spec_parsed_inputs + [spec_parsed_input]}) + + def update_content(self, content: Union[str, 'Message'], parsed_content: Dict = {}): + if isinstance(content, str): + self.update_attributes({"content": content}) + self.update_attributes({"step_content": self.step_content + f"\n{content}"}) + else: + self.update_attributes({"content": content.content}) + self.update_attributes({"step_content": self.step_content + f"\n{content.content}"}) + + def update_parsed_content(self, parsed_content: Dict = {}): + self.update_attributes({"parsed_content": parsed_content}) + self.update_attributes({"parsed_contents": self.parsed_contents + [parsed_content]}) + + def update_spec_parsed_content(self, spec_parsed_content: Dict = {}): + self.update_attributes({"spec_parsed_content": spec_parsed_content}) + self.update_attributes({"spec_parsed_contents": self.spec_parsed_contents + [spec_parsed_content]}) + + def update_attributes(self, attributes: dict): + '''update message attributes''' + for k, v in attributes.items(): + self.update_attribute(k, v) + + def update_attribute(self, key: str, value): + if hasattr(self, key): + setattr(self, key, value) + self.end_datetime = getCurrentDatetime(self.datetime_format) + else: + raise AttributeError(f"{key} is not a valid property of {self.__class__.__name__}") + + def to_dict_message(self, ) -> Dict: + return vars(self) + + def to_tuple_message( + self, + return_all: bool = True, + content_key: Literal[ + 'input_text', + 'content', + 'step_conetent', + 'parsed_content', + 'spec_parsed_contents', + ] = "content", + ) -> Union[str, Tuple[str, str]]: + content = self.to_str_content(False, content_key) + if return_all: + return (self.role_name, content) + else: + return (content) + + def to_str_content( + self, + content_key: Literal[ + 'input_text', + 'content', + 'step_conetent', + 'parsed_content', + 'parsed_contents', + 'spec_parsed_content', + 'spec_parsed_contents', + ] = "content", + with_tag=False + ) -> str: + # TODO while role_type is USER return input_query, else return role_content + response = self.content or self.input_text + if content_key == "content": + content = response + elif content_key == "input_text": + content = self.input_text + elif content_key == "step_content": + content = self.step_content or response + elif content_key == 
"parsed_content": + content = "\n".join([v for k, v in self.parsed_content.items()]) or response + # content = "\n".join([f"**{k}:** {v}" for k, v in self.parsed_content.items()]) or response + elif content_key == "spec_parsed_content": + content = "\n".join([f"**{k}:** {v}" for k, v in self.spec_parsed_content.items()]) or response + elif content_key == "parsed_contents": + content = "\n".join([v for po in self.parsed_contents for k,v in po.items()]) or response + elif content_key == "spec_parsed_contents": + content = "\n".join([f"**{k}:** {v}" for po in self.spec_parsed_contents for k,v in po.items()]) or response + else: + content = response + + if with_tag: + start_tag = f"<{self.role_type}-{self.role_name}-message>" + end_tag = f"" + return f"{start_tag}\n{content}\n{end_tag}" + else: + return content + + def get_value(self, key: str) -> any: + """ + Get the value of the given key from the message. + + :param key: The key of the attribute to retrieve. + :return: The value associated with the key. + """ + if hasattr(self, key): + return getattr(self, key, None) + raise AttributeError(f"Message don't have attribute {key}") + + def get_attribute_type(self, key): + return type(getattr(self, key, None)) + + def __str__(self) -> str: + # key_str = '\n'.join([k for k, v in vars(self).items()]) + # logger.debug(f"{key_str}") + return "\n".join([": ".join([k, str(v)]) for k, v in vars(self).items()]) + \ No newline at end of file diff --git a/muagent/schemas/models/__init__.py b/muagent/schemas/models/__init__.py new file mode 100644 index 0000000..046c4b7 --- /dev/null +++ b/muagent/schemas/models/__init__.py @@ -0,0 +1,11 @@ +from .model import ModelConfig, LLMConfig +from .llm_shemas import * + + +__all__ = [ + "ModelConfig", "LLMConfig" + + "ChatMessage", "FunctionCallData", "ToolCall", "LLMOuputMessage", + "Choice", "UsageData", "LLMResponse", + +] \ No newline at end of file diff --git a/muagent/schemas/models/llm_shemas.py b/muagent/schemas/models/llm_shemas.py new file mode 100644 index 0000000..45f54e9 --- /dev/null +++ b/muagent/schemas/models/llm_shemas.py @@ -0,0 +1,50 @@ +from pydantic import BaseModel, Field +from typing import List, Dict, Optional, Union +from enum import Enum + + + +class ChatMessage(BaseModel): + role: str + content: str + + +class FunctionCallData(BaseModel): + name: str + arguments: Union[str, dict] + + +class ToolCall(BaseModel): + id: Optional[Union[str, int]] = None + type: str = "function" + function: FunctionCallData + + +class LLMOuputMessage(BaseModel): + content: Optional[str] = None + role: str + tool_calls: List[ToolCall] = [] + + +class Choice(BaseModel): + finish_reason: str + index: int = 0 + message: LLMOuputMessage + + +class UsageData(BaseModel): + completion_tokens: int + prompt_tokens: int + total_token: int + + +class LLMResponse(BaseModel): + choices: List[Choice] + created: int = 0 + id: str + model: str + object: str + usage: Optional[UsageData] = None + + + diff --git a/muagent/schemas/models/model.py b/muagent/schemas/models/model.py new file mode 100644 index 0000000..4e4a4fb --- /dev/null +++ b/muagent/schemas/models/model.py @@ -0,0 +1,56 @@ + + +from pydantic import BaseModel, root_validator +from typing import List, Dict, Optional, Union, Literal + + + +class ModelConfig(BaseModel): + """The dataclass for model config.""" + + config_name: Optional[str] = None + """The name of the model configuration. 
It defaults to model_name when not provided."""
+
+    model_type: str
+    """The type of the model wrapper, used to identify the model wrapper
+    class from the model configuration."""
+
+    model_name: str
+    """The name of the model, which is used in model api calling."""
+
+    api_key: Optional[str] = None
+    """The api key of the model, which is used in model api calling."""
+
+    api_url: Optional[str] = None
+    """The api url of the model, which is used in model api calling."""
+
+    max_tokens: Optional[int] = None
+    """The max_tokens of the model, which is used in model api calling."""
+
+    top_p: float = 0.9
+    """The top_p of the model, which is used in model api calling."""
+
+    temperature: float = 0.3
+    """The temperature of the model, which is used in model api calling."""
+
+    stream: bool = False
+    """The stream mode of the model, which is used in model api calling."""
+
+    @root_validator(pre=True)
+    def set_default_config_name(cls, values):
+        """Set config_name to model_name if config_name is not provided."""
+        if 'config_name' not in values or values['config_name'] is None:
+            values['config_name'] = values.get('model_name')
+        return values
+
+
+
+class LLMConfig(BaseModel):
+    """Legacy config; kept temporarily and scheduled for removal."""
+    model_name: str = "gpt-3.5-turbo"
+    model_engine: str = "openai"
+    temperature: float = 0.3
+    stop: Optional[Union[List[str], str]] = None
+    api_key: str = ""
+    api_base_url: str = ""
+    llm: Optional[str] = ""
\ No newline at end of file
diff --git a/muagent/schemas/project_config.py b/muagent/schemas/project_config.py
new file mode 100644
index 0000000..ff19c19
--- /dev/null
+++ b/muagent/schemas/project_config.py
@@ -0,0 +1,114 @@
+
+from pydantic import BaseModel, Field
+from typing import (
+    List,
+    Dict,
+    Optional,
+    Union,
+    Literal,
+    Any
+)
+
+from .models import ModelConfig, LLMConfig
+from .agent_config import AgentConfig, PromptConfig
+from .db import GBConfig, TBConfig
+
+
+
+class ProjectConfig(BaseModel):
+    """The dataclass of project config"""
+
+    agent_configs: Optional[Dict[str, AgentConfig]] = {}
+    """Mapping from config name to agent config."""
+
+    prompt_configs: Optional[Dict[str, PromptConfig]] = {}
+    """Mapping from config name to prompt config."""
+
+    model_configs: Optional[Dict[str, Any]] = {}
+    """Mapping from config name to model config."""
+
+    graph: Any = None
+    """The project's graph definition (if any)."""
+
+    def extend_agent_configs(
+        self,
+        agent_configs: Union[AgentConfig, List[AgentConfig], Dict[str, AgentConfig]]
+    ):
+        if isinstance(agent_configs, AgentConfig):
+            self.agent_configs.update({agent_configs.config_name: agent_configs})
+        elif isinstance(agent_configs, List):
+            self.agent_configs.update({
+                i.config_name: i for i in agent_configs
+                if isinstance(i, AgentConfig)
+            })
+        elif isinstance(agent_configs, Dict):
+            self.agent_configs.update(agent_configs)
+
+    def extend_prompt_configs(
+        self,
+        prompt_configs: Union[PromptConfig, List[PromptConfig], Dict[str, PromptConfig]]
+    ):
+        if isinstance(prompt_configs, PromptConfig):
+            self.prompt_configs.update({prompt_configs.config_name: prompt_configs})
+        elif isinstance(prompt_configs, List):
+            self.prompt_configs.update({
+                i.config_name: i for i in prompt_configs
+                if isinstance(i, PromptConfig)
+            })
+        elif isinstance(prompt_configs, Dict):
+            self.prompt_configs.update(prompt_configs)
+
+    def extend_model_configs(
+        self,
+        model_configs: Union[ModelConfig, List[ModelConfig], Dict[str, ModelConfig]]
+    ):
+        if isinstance(model_configs, ModelConfig):
+            self.model_configs.update({model_configs.config_name: model_configs})
+        elif isinstance(model_configs, List):
+            self.model_configs.update({
+                i.config_name: i for i in 
model_configs
+                if isinstance(i, ModelConfig)
+            })
+        elif isinstance(model_configs, Dict):
+            self.model_configs.update(model_configs)
+
+    def extend_graph(self, graph):
+        """Placeholder: merging graphs is not implemented yet."""
+        pass
+
+    def __add__(self, other: 'ProjectConfig') -> 'ProjectConfig':
+        if isinstance(other, ProjectConfig):
+            self.extend_agent_configs(other.agent_configs)
+            self.extend_model_configs(other.model_configs)
+            self.extend_prompt_configs(other.prompt_configs)
+            self.extend_graph(other.graph)
+            return self
+        else:
+            raise ValueError(f"cannot add unsupported type {type(other)}")
+
+
+
+class EKGProjectConfig(BaseModel):
+    """The dataclass of EKG project config"""
+
+    config_name: str = "default"
+    """The config name of the EKG project."""
+
+    model_configs: Optional[Dict[str, Union[ModelConfig, Any]]]
+    """Mapping from config name to model config."""
+
+    embed_configs: Optional[Dict[str, ModelConfig]]
+    """Mapping from config name to embedding-model config."""
+
+    agent_configs: Optional[Dict[str, AgentConfig]]
+    """Mapping from config name to agent config."""
+
+    prompt_configs: Optional[Dict[str, PromptConfig]]
+    """Mapping from config name to prompt config."""
+
+    db_configs: Optional[Dict[str, Union[GBConfig, TBConfig]]]
+    """Mapping from config name to graph-base/table-base DB config."""
\ No newline at end of file
diff --git a/muagent/service/ekg_construct/ekg_construct_base.py b/muagent/service/ekg_construct/ekg_construct_base.py
index f13466e..be19262 100644
--- a/muagent/service/ekg_construct/ekg_construct_base.py
+++ b/muagent/service/ekg_construct/ekg_construct_base.py
@@ -22,7 +22,8 @@ from muagent.schemas.db import *
 from muagent.schemas.common import *
 from muagent.db_handler import *
-from muagent.orm import table_init
+# from muagent.orm import table_init
+from muagent.db_handler import table_init
 
 from muagent.base_configs.env_config import EXTRA_KEYWORDS_PATH
 from muagent.connector.configs.generate_prompt import *
@@ -193,6 +194,12 @@ def init_gb(self, do_init: bool=None):
                 self.create_gb_tags_and_edgetypes()
                 self.waiting_tags_edgetypes_initialize()
+                # print('Initializing node tags and edge types; waiting 20 seconds......')
+                # time.sleep(20)
+            else:
+                self.gb.add_hosts('storaged0', 9779)
+                # create node tags and edge types
+                self.create_gb_tags_and_edgetypes()
         else:
             self.gb = None
@@ -323,6 +330,28 @@ def create_gb_tags_and_edgetypes(self):
+
+        def _check():
+            node_types = [i for i in TYPE2SCHEMA.keys() if i!='edge']
+            tags = self.gb.show_tags()
+            tag_names = [tag["Name"] for tag in tags]
+            tag_flag = set(tag_names) == set(node_types)
+
+            edges = self.gb.show_edge_type()
+            edge_names = [edge["Name"] for edge in edges]
+            inset_edges = [f"{i}{k}{j}"
+                for i in node_types
+                for j in node_types
+                for k in ["_route_", "_extend_", "_conclude_"]
+            ]
+            edge_flag = set(edge_names) == set(inset_edges)
+            logger.info(f"tag_flag={tag_flag}, edge_flag={edge_flag}")
+            #
+            return tag_flag and edge_flag
+
+        # if the tags and edge types already exist, skip re-creation
+        if _check(): return
+
         # 节点标签和属性 (done)
         for node_type, schema in TYPE2SCHEMA.items():
             if node_type == 'edge':
@@ -364,7 +393,7 @@ def create_gb_tags_and_edgetypes(self):
         # 边类型(名称)
         node_types = list(TYPE2SCHEMA.keys())
-        logger.info(node_types)
+        node_types = [i for i in TYPE2SCHEMA.keys() if i!='edge']
         for i in range(len(node_types)):
             for j in range(len(node_types)):
                 if node_types[i] != 'edge' and node_types[j] != 'edge':  # 排除 node_type 为 'edge'
@@ -374,9 +403,11 @@ def create_gb_tags_and_edgetypes(self):
                     self.gb.create_edge_type(edge_type2, edge_attributes_dict)
                     edge_type3 = f"{node_types[i]}_conclude_{node_types[j]}"
                     self.gb.create_edge_type(edge_type3, edge_attributes_dict)
-
-
+        time.sleep(5)
+        while not _check():
+            logger.info('Initializing node tags and edge types; waiting 5 seconds......')
+            
time.sleep(5) def update_graph( self, @@ -874,6 +905,7 @@ def get_node_by_id( ) -> GNode: if service_type=="gbase": node = self.gb.get_current_node({'id': nodeid}, node_type=node_type) + if node is None: return node node = self._normalized_nodes_type(nodes=[node])[0] else: node = GNode(id=nodeid, type="", attributes={}) diff --git a/muagent/service/ekg_inference/intention_match_rule.py b/muagent/service/ekg_inference/intention_match_rule.py index f94b6ab..19a821c 100644 --- a/muagent/service/ekg_inference/intention_match_rule.py +++ b/muagent/service/ekg_inference/intention_match_rule.py @@ -1,5 +1,5 @@ import re -import Levenshtein +import edit_distance as ed from muagent.schemas.common import GNode @@ -13,13 +13,14 @@ def edit_distance(cls, node: GNode, pattern=None, **kwargs): desc: str = node.attributes.get('description', '') if pattern is None: - return -Levenshtein.distance(desc, s) + return -ed.edit_distance(desc, s)[0] desc_list = re.findall(pattern, desc) if not desc_list: return -float('inf') - return max([-Levenshtein.distance(x, s) for x in desc_list]) + return max([-ed.edit_distance(x, s)[0] for x in desc_list]) + @classmethod def edit_distance_integer(cls, node: GNode, **kwargs): diff --git a/muagent/service/ekg_inference/intention_router.py b/muagent/service/ekg_inference/intention_router.py index 2bb5136..f41d063 100644 --- a/muagent/service/ekg_inference/intention_router.py +++ b/muagent/service/ekg_inference/intention_router.py @@ -103,8 +103,7 @@ def _func(node: GNode, rule: Callable): return select_node, error_msg def get_intention_by_node_info_match( - self, root_node_id: str, filter_attribute: Optional[dict] = None, - gb_handler: Optional[GBHandler] = None, + self, root_node_id: str, gb_handler: Optional[GBHandler] = None, rule: Union[Rule_type, list[Rule_type]] = None, **kwargs ) -> dict[str, Any]: gb_handler = gb_handler if gb_handler is not None else self.gb_handler @@ -124,10 +123,7 @@ def get_intention_by_node_info_match( RuleRetInfo(node_id=root_node_id, error_msg=error_msg, status=RouterStatus.OTHERS.value)) if not (root_node_id and self._node_exist(root_node_id, gb_handler)): - if not root_node_id: - error_msg = f'No node matches attribute {filter_attribute}.' - else: - error_msg = f'Node(id={root_node_id}, type={self._node_type}) does not exist!' + error_msg = f'Node(id={root_node_id}, type={self._node_type}) does not exist!' 
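+            # attribute-based matching was removed from this branch; a missing or
+            # unknown node id is now the only failure reported here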
return asdict( RuleRetInfo(node_id=root_node_id, error_msg=error_msg, status=RouterStatus.OTHERS.value)) @@ -284,6 +280,8 @@ def get_intention_whether_execute(self, query: str, agent=None) -> bool: return False def get_intention_consult_which(self, query: str, agent=None, root_node_id: Optional[str]=None) -> str: + if isinstance(query, (list, tuple)): + query = query[0] agent = agent if agent else self.agent query_consult_which = itp.CONSULT_WHICH_PROMPT.format(query=query) ans = agent.predict(query_consult_which) @@ -386,13 +384,12 @@ def _dfs(s: str, ancestor: str, path: str, out: dict, visited: set): if child in nodes: if ancestor in out: out.pop(ancestor) + visited.add(ancestor) temp_ancestor = child else: temp_ancestor = ancestor child_path = split.join((path, child)) _dfs(child, temp_ancestor, child_path, out, visited) - if s in nodes: - visited.add(s) if len(nodes) == 0: return dict() diff --git a/muagent/service/ui_file_service/code_base_cds.py b/muagent/service/ui_file_service/code_base_cds.py index f1d0e31..43cc0da 100644 --- a/muagent/service/ui_file_service/code_base_cds.py +++ b/muagent/service/ui_file_service/code_base_cds.py @@ -6,7 +6,7 @@ @desc: ''' from loguru import logger -from muagent.orm.db import with_session +from muagent.db_handler.db import with_session from muagent.schemas.kb.base_schema import CodeBaseSchema diff --git a/muagent/service/ui_file_service/document_base_cds.py b/muagent/service/ui_file_service/document_base_cds.py index 43f8700..c8bc73a 100644 --- a/muagent/service/ui_file_service/document_base_cds.py +++ b/muagent/service/ui_file_service/document_base_cds.py @@ -1,4 +1,4 @@ -from muagent.orm.db import with_session +from muagent.db_handler.db import with_session from muagent.schemas.kb.base_schema import KnowledgeBaseSchema diff --git a/muagent/service/ui_file_service/document_file_cds.py b/muagent/service/ui_file_service/document_file_cds.py index 9cdd27e..801ff39 100644 --- a/muagent/service/ui_file_service/document_file_cds.py +++ b/muagent/service/ui_file_service/document_file_cds.py @@ -1,4 +1,4 @@ -from muagent.orm.db import with_session +from muagent.db_handler.db import with_session from muagent.schemas.kb.base_schema import KnowledgeFileSchema, KnowledgeBaseSchema from muagent.schemas.kb.file_schema import DocumentFile diff --git a/muagent/service/utils.py b/muagent/service/utils.py index ae8f7e9..2a2592f 100644 --- a/muagent/service/utils.py +++ b/muagent/service/utils.py @@ -33,10 +33,10 @@ def decode_biznodes( **{**{"id": node.id, "type": node.type}, **node.attributes} ) - if node.type == "opsgptkg_task": - logger.debug(f"schema:{ schema}") - logger.debug(f"node_data:{ type(node_data)}") - logger.debug(f"node_data:{ node_data}") + # if node.type == "opsgptkg_task": + # logger.debug(f"schema:{ schema}") + # logger.debug(f"node_data:{ type(node_data)}") + # logger.debug(f"node_data:{ node_data}") node_data = { k:v @@ -44,8 +44,8 @@ def decode_biznodes( if k not in ["type", "ID", "id", "extra"] } - if node.type == "opsgptkg_task": - logger.debug(f"node_data:{ node_data}") + # if node.type == "opsgptkg_task": + # logger.debug(f"node_data:{ node_data}") # update agent/tool nodes and edges agents = node_data.pop("agents", []) @@ -70,9 +70,9 @@ def decode_biznodes( attributes={} )) - if node.type == "opsgptkg_task": - logger.debug(f"node_data:{ node_data}") - logger.debug(f"node.attributes:{ node.attributes}") + # if node.type == "opsgptkg_task": + # logger.debug(f"node_data:{ node_data}") + # logger.debug(f"node.attributes:{ node.attributes}") 
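+        # rebuild the business node only after its agents/tools have been split
+        # out into separate GNodes and edges above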
         new_nodes.append(GNode(**{
             "id": node.id,
diff --git a/muagent/tools/__init__.py b/muagent/tools/__init__.py
index 5200c04..5a6a5c3 100644
--- a/muagent/tools/__init__.py
+++ b/muagent/tools/__init__.py
@@ -12,12 +12,14 @@
 from .ocr_tool import BaiduOcrTool
 from .stock_tool import StockInfo, StockName
 from .codechat_tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
+from .undercover import *
+from .werewolf import *
 
 
 IMPORT_TOOL = [
     WeatherInfo, DistrictInfo, Multiplier, WorldTimeGetTimezoneByArea,
     KSigmaDetector, MetricsQuery, DDGSTool, DocRetrieval, CodeRetrieval,
-    BaiduOcrTool, StockInfo, StockName, CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
+    BaiduOcrTool, StockInfo, StockName, CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code,
 ]
 
 TOOL_SETS = [tool.__name__ for tool in IMPORT_TOOL]
@@ -29,3 +31,6 @@
     "toLangchainTools", "get_tool_schema", "tool_sets", "BaseToolModel"
 ] + TOOL_SETS
+
+def get_tool(tool_name: str) -> BaseToolModel:
+    return BaseToolModel._from_name(tool_name)
\ No newline at end of file
diff --git a/muagent/tools/base_tool.py b/muagent/tools/base_tool.py
index 507822a..814a210 100644
--- a/muagent/tools/base_tool.py
+++ b/muagent/tools/base_tool.py
@@ -1,16 +1,109 @@
+from abc import ABCMeta
+
 from langchain.agents import Tool
 from langchain.tools import StructuredTool
 from langchain.tools.base import ToolException
 from pydantic import BaseModel, Field
-from typing import List, Dict
-# import jsonref
-import json
-
+from typing import List, Dict, Any, Type
+try:
+    import jsonref
+except ImportError:
+    jsonref = None
 
-class BaseToolModel:
+import json
+import copy
+
+
+def simplify_schema(schema: Dict[str, Any], definitions, no_required=False, depth=0) -> Dict[str, Any]:
+    """Simplify a pydantic JSON schema by inlining $ref references and dropping definitions."""
+    if definitions is None: return schema
+    schema_new = copy.deepcopy(schema)
+    # drop the title field
+    schema_new.pop('title', None)
+    # walk through the properties and inline any referenced sub-models
+    if 'properties' in schema:
+        for key, value in schema['properties'].items():
+            for k, v in value.items():
+
+                if k == "allOf":
+                    ref_model_name = v[0]['$ref'].split('/')[-1]  # extract the model name
+                    ref_model_value = simplify_schema(definitions[ref_model_name], definitions, no_required=True, depth=depth+1)
+                    schema_new["properties"][key].pop(k)
+                    schema_new["properties"][key].update(ref_model_value)
+
+                if isinstance(v, dict) and '$ref' in v:
+                    ref_model_name = v['$ref'].split('/')[-1]  # extract the model name
+                    ref_model_value = simplify_schema(definitions[ref_model_name], definitions, no_required=True, depth=depth+1)
+                    schema_new["properties"][key][k] = ref_model_value
+
+            schema_new["properties"][key].pop("title", None)
+    # drop the definitions section
+    if no_required:
+        schema_new.pop('required', None)
+    schema_new.pop('definitions', None)
+
+    return schema_new
+
+
+class _ToolWrapperMeta(ABCMeta):
+    """A metaclass that auto-registers every tool wrapper subclass,
+    both by class name and by its `name` attribute."""
+
+    def __new__(mcs, name: Any, bases: Any, attrs: Any) -> Any:
+        # placeholder: error-handling wrappers around __call__ can be added here
+        return super().__new__(mcs, name, bases, attrs)
+
+    def __init__(cls, name: Any, bases: Any, attrs: Any) -> None:
+        if not hasattr(cls, "_registry"):
+            cls._registry = {}  # keyed by class name
+            cls._toolname_registry = {}  # keyed by the class's `name` attribute
+        else:
+            cls._registry[name] = cls
+            cls._toolname_registry[cls.name] = cls
+        super().__init__(name, bases, attrs)
+
+
+class BaseToolModel(metaclass=_ToolWrapperMeta):
     name = "BaseToolModel"
     description = "Tool Description"
 
+    @classmethod
+    def _from_name(cls, tool_name: str) -> 'BaseToolModel':
+        """Get the specific tool wrapper by class name or tool name."""
+        if tool_name in cls._registry:
+            return cls._registry[tool_name]()  # type: ignore[return-value]
+        elif tool_name in cls._toolname_registry:
+            return cls._toolname_registry[tool_name]()  # type: ignore[return-value]
+        else:
+            raise KeyError(
+                f"Tool library is missing"
+                f" {tool_name}, please check your tool name"
+            )
+
+    @classmethod
+    def input_to_json_schema(cls) -> Dict[str, Any]:
+        '''Transform the input schema to a plain json structure'''
+        try:
+            return jsonref.loads(cls.ToolInputArgs.schema_json())
+        except Exception:
+            return simplify_schema(
+                cls.ToolInputArgs.schema(),
+                cls.ToolInputArgs.schema().get("definitions")
+            )
+
+    @classmethod
+    def output_to_json_schema(cls) -> Dict[str, Any]:
+        '''Transform the output schema to a plain json structure'''
+        try:
+            return jsonref.loads(cls.ToolOutputArgs.schema_json())
+        except Exception:
+            return simplify_schema(
+                cls.ToolOutputArgs.schema(),
+                cls.ToolOutputArgs.schema().get("definitions")
+            )
+
     class ToolInputArgs(BaseModel):
         """
         Input for MoveFileTool.
@@ -32,7 +125,7 @@
         key2: str = Field(..., description="hello world!!")
 
     @classmethod
-    def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs:
+    def run(cls) -> ToolOutputArgs:
         """Execute your tool!"""
         pass
diff --git a/muagent/tools/metrics_query.py b/muagent/tools/metrics_query.py
index c3336c7..6198663 100644
--- a/muagent/tools/metrics_query.py
+++ b/muagent/tools/metrics_query.py
@@ -25,7 +25,8 @@
 
     class ToolOutputArgs(BaseModel):
         datas: List[float] = Field(..., description="监控时序数组")
 
-    def run(machine_ip, time):
+    @classmethod
+    def run(cls, machine_ip, time):
         """excute your tool!"""
         data = [0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567,
diff --git a/muagent/tools/undercover.py b/muagent/tools/undercover.py
new file mode 100644
index 0000000..6746bf6
--- /dev/null
+++ b/muagent/tools/undercover.py
@@ -0,0 +1,315 @@
+import os
+from typing import (
+    List,
+    Dict
+)
+from loguru import logger
+from pydantic import BaseModel, Field
+import random
+
+from .base_tool import BaseToolModel
+from ..models import get_model, ModelConfig
+
+
+class SeatAssignerTool(BaseToolModel):
+    """
+    This tool assigns seat positions to players and formats them in a markdown table.
+    Example Output:
+    ```
+    | 座位 | 玩家 |
+    |---|---|
+    | 1 | **张伟** |
+    | 2 | **李静** |
+    | 3 | **王鹏** |
+    | 4 | **人类玩家** |
+    ```
+    """
+    name: str = "谁是卧底-座位分配"
+    description: str = "谁是卧底的座位分配工具,可以将玩家顺序打乱随机分配座位"
+
+    class ToolInputArgs(BaseModel):
+        """Input for SeatAssigner."""
+        pass  # No specific parameters required for this tool
+
+    class ToolOutputArgs(BaseModel):
+        """Output for SeatAssigner."""
+        table: str = Field(..., description="Markdown table of seating arrangement")
+
+    @classmethod
+    def run(cls, **kwargs) -> str:
+        """Execute the seat assignment tool."""
+        players = [["张伟", "agent_张伟"], ["李静", "agent_李静"], ["王鹏", "agent_王鹏"], ["人类玩家", "agent_人类玩家"]]
+        # Shuffle players to assign them to random seats
+        random.shuffle(players)
+        # Create the markdown table
+        markdown_table = "\n\n| 座位 | 玩家 |\n|---|---|\n" + "\n".join(
+            f"| {i+1} | **{players[i][0]}** |" for i in range(len(players))
+        )
+        return markdown_table
+
+
+class RoleAssignerTool(BaseToolModel):
+    """
+    This class assigns roles and words to players in a game.
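+    One player secretly receives the undercover word while the rest share a similar civilian word.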
+ The output will include player names, agent names, agent descriptions, and secret words based on their role type. + """ + name: str = "谁是卧底-角色分配" + description: str = "谁是卧底的角色分配工具,可以为每一位玩家分配一个单词和人物角色。" + + class ToolInputArgs(BaseModel): + """Input for assigning roles.""" + pass + + class ToolOutputArgs(BaseModel): + """Output for assigned roles.""" + roles: List[Dict[str, str]] = Field(..., description="List of roles assigned to players") + + @classmethod + def run(cls, **kwargs) -> List[Dict[str, str]]: + words = [ + ["苹果", "梨"], + ["猫", "狗"], + ["摩托车", "自行车"], + ["太阳", "月亮"], + ["红色", "粉色"], + ["大象", "长颈鹿"], + ["铅笔", "钢笔"], + ["牛奶", "豆浆"], + ["河", "湖"], + ["面包", "蛋糕"], + ["饺子", "包子"], + ["冬天", "夏天"], + ["电视", "电脑"], + ["铅笔", "橡皮"], + ["跑步", "游泳"], + ["手机", "平板"], + ["鱼", "虾"], + ["空调", "风扇"], + ["马", "驴"], + ["书", "杂志"], + ["草", "树"], + ["杯子", "碗"], + ["米饭", "面条"], + ["饼干", "蛋糕"], + ["雨伞", "雨衣"], + ["猪", "牛"], + ["白菜", "生菜"], + ["吉他", "钢琴"], + ["飞机", "火车"], + ["镜子", "眼镜"] + ] + + player_names = ["张伟", "李静", "王鹏", "人类玩家"] + roles = ["平民_1", "平民_2", "平民_3", "卧底_1"] + random.shuffle(player_names) + random.shuffle(roles) + + word_idx = random.randint(0, len(words) - 1) + under_cover_word_idx = random.randint(0, 1) + + result = [] + for i in range(len(player_names)): + r = { + "player_name": player_names[i], + "agent_name": f"agent_{player_names[i]}", + "agent_description": roles[i], + "单词": words[word_idx][1 - under_cover_word_idx] if roles[i].startswith("平民") else words[word_idx][under_cover_word_idx] + } + result.append(r) + + return result + + +class GameActionTool(BaseToolModel): + name = "谁是卧底-游戏行动" + description = "谁是卧底的游戏行动工具,需要根据记忆的上下文信息,当前任务、以及你拿到的单词信息来进行回答响应。" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + content: str + + @classmethod + def run(cls, **kwargs) -> str: + """Execute your tool!""" + + template = ( + '##背景##\n' + '您正在参加“谁是卧底”这款游戏,您的目标是:想办法击杀与自己身份不同的所有玩家,获得胜利。\n' + '\n' + '##游戏介绍##\n' + '在“谁是卧底”游戏中,每位玩家会被分配一个[单词](玩家可见)和一个身份(玩家不可见,包括[平民]和[卧底]两种身份),卧底的[单词]跟[平民]不同,但有许多共同的特征。\n' + '\n' + '##任务##\n' + '1. 根据**游戏进展中主持人的最新通知**,感知当前的任务:讨论 or 票选卧底,准备发言。\n' + '2. 如果任务是讨论,感知分配给您的[单词],描述它的某一特征(**描述内容可真可假,禁止描述已经提到过的特征**),您的目标是:让其他玩家相信该特征与他们的[单词]是相符的;否则,投票给某个当前存活玩家,并说明理由,您的目标是:让其他玩家相信,该玩家给出的特征与大家的[单词]都不符。\n' + '\n' + '##发言示例##\n' + '(任务是讨论)一种植物,可食用。\n' + '(任务是票选卧底)我投票给李静,因为对比所有人的发言,他的描述和其他的有明显区别。\n' + '\n' + '##游戏进展##\n' + '{memory}\n' + '\n' + '##注意##\n' + '- 无论您的任务是什么,**禁止泄露自己的[单词],发言内容尽可能简洁!!!**。\n' + '- 如果您的任务是讨论,**描述的特征可真可假,但要避免已经提到过的特征**;如果是票选卧底,**一定要明确表示投票给哪一位玩家(禁止给自己或已经死亡的玩家投票)**。\n' + '- 禁止描述任何没有发生过的事情。\n' + '\n' + '##游戏经验##\n' + '如果任务是讨论,以下是描述[单词]特征时的一些经验:\n' + '1. 保持模糊性:特征不宜过于明显(尤其当您是首位发言的玩家时),这样很容易别人推测出自己的[单词]。\n' + '2. 逐渐清晰:与其他玩家给出的特征相比,您的特征应该更清晰,否则很容易被其他玩家怀疑。\n' + '3. 
'定位身份:如果您发现多个玩家的特征跟您的[单词]都不符,那么自己的身份很可能是[卧底],应该推测他们的[单词]是什么,**编造**跟他们[单词]相符的特征。\n'
+            '\n'
+            '##输出##\n'
+            'Python可直接解析的jsonstr,格式如下:\n'
+            '{{\"thought\": 感知自己的名字、位置(根据主持人的【身份通知】!!!注意您不是人类玩家)、[单词]、当前任务、哪些特征已经被提出来、推测其他玩家的[单词]是什么、自己是否是[卧底]、如何保护自己,分析内容不超过120字, \"output\": 您的发言(避免泄露[单词],避免投票给自己,避免重复特征,直接说出符合您的身份的话,不要输出其他信息)}}\n'
+            '以{{开头,任何其他内容都是不允许的!\n'
+        )
+        model_config = None
+        try:
+            model_config = ModelConfig(
+                config_name="codefuse_default",
+                model_type=os.environ.get("DEFAULT_MODEL_TYPE"),
+                model_name=os.environ.get("DEFAULT_MODEL_NAME"),
+                api_key=os.environ.get("DEFAULT_API_KEY"),
+                api_url=os.environ.get("DEFAULT_API_URL"),
+            )
+            memory = kwargs.get("memory") or ""
+            model = get_model(model_config)
+            content = model.predict(template.format(memory=memory))
+        except Exception as e:
+            content = f"无法正确调用模型: {e}, {model_config}"
+        return content
+
+
+class AgentZhangweiTool(GameActionTool):
+    name = "谁是卧底-张伟"
+    description = (
+        "你是一个智能体(Agent),你正在模拟玩家参与谁是卧底这场游戏,在游戏中你的名字是张伟"
+    )
+
+class AgentLijingTool(GameActionTool):
+    name = "谁是卧底-李静"
+    description = (
+        "你是一个智能体(Agent),你正在模拟玩家参与谁是卧底这场游戏,在游戏中你的名字是李静"
+    )
+
+
+class AgentWangpengTool(GameActionTool):
+    name = "谁是卧底-王鹏"
+    description = (
+        "你是一个智能体(Agent),你正在模拟玩家参与谁是卧底这场游戏,在游戏中你的名字是王鹏"
+    )
+
+
+class GameEndCheckerTool(BaseToolModel):
+    name = "谁是卧底-胜利条件判断"
+    description = "谁是卧底的胜利条件判断工具,判断当前谁是卧底游戏是否结束。"
+
+    class ToolInputArgs(BaseModel):
+        pass
+
+    class ToolOutputArgs(BaseModel):
+        content: str
+
+    @classmethod
+    def run(cls, **kwargs) -> str:
+        """Execute your tool!"""
+
+        template = (
+            '##本局游戏历史记录##\n'
+            '{memory}\n\n'
+            '##背景##\n'
+            '你是一个逻辑判断大师,你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,现在需要判断当前游戏是否结束。\n\n'
+            '##任务##\n'
+            '你的任务是判断当前游戏是否结束,规则如下:\n'
+            '根据【重要信息】感知每一轮被投票死亡的玩家。 统计目前存活的[平民]玩家数量、[卧底]玩家数量。格式{{\"存活的卧底\":[player_name], \"存活的平民\":[player_name]}},判断以下条件中的一个是否满足:\n'
+            '1. \t卧底玩家全部已经死亡(即 存活[卧底]数量为0)。\n'
+            '2. 存活的[平民]数量与存活的[卧底]数量相等。\n'
+            '如果其中一个条件满足,则游戏结束;否则,游戏需要继续。\n\n'
+            '##输出##\n'
+            '返回jsonstr 格式。{{\"thought\": str, \"存活的玩家信息\": {{\"存活的卧底\":[player_name], \"存活的平民\":[player_name]}}, \"isEnd\": \"是\" or \"否\"}}\n'
+            '-thought **根据本局游戏历史记录** 分析 游戏最开始有哪些玩家, 他们的身份是什么, 投票导致死亡的玩家有哪些? 分析当前存活的玩家有哪些 ? 是否触发了游戏结束条件? 等等\n\n'
+            '##注意事项##\n'
+            '1. 所有玩家的座位、身份、agent_name、存活状态、游戏进展等信息在开头部分已给出。\n'
+            '2. \"是\" or \"否\" 如何选择?若游戏结束,则为\"是\",否则为\"否\"。\n'
+            '3. 请直接输出jsonstr,不用输出markdown格式。\n\n'
+            '4. 
游戏可能进行了不只一轮,可能有1个或者2个玩家已经死亡,请注意感知\n' + '##结果##\n\n' + ) + + model_config = None + try: + model_config = ModelConfig( + config_name="codefuse_default", + model_type=os.environ.get("DEFAULT_MODEL_TYPE"), + model_name=os.environ.get("DEFAULT_MODEL_NAME"), + api_key=os.environ.get("DEFAULT_API_KEY"), + api_url=os.environ.get("DEFAULT_API_URL"), + ) + memory = kwargs.get("memory") or "" + model = get_model(model_config) + content = model.predict(template.format(memory=memory)) + except Exception as e: + content = f"无法正确调用模型: {e}, {model_config}" + return content + +class GameOutcomeCheckerTool(BaseToolModel): + name = "谁是卧底-结果输出" + description = "谁是卧底的结果输出工具,判断谁是卧底游戏中最终的胜利方是谁,并输出角色分配情况" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + content: str + + @classmethod + def run(cls, **kwargs) -> str: + """Execute your tool!""" + + template = ( + '##本局游戏历史记录##\n' + '{memory}\n' + '\n' + '##背景##\n' + '您正在参与“谁是卧底”这个游戏,角色是[主持人]。现在游戏已经结束,您需要判断胜利的一方是谁。\n' + '\n' + '##任务##\n' + '统计目前存活的[平民]玩家数量、[卧底]玩家数量。判断以下条件中的哪一个满足:\n' + '1.[卧底]数量为0。\n' + '2.[平民]数量与[卧底]数量相等。\n' + '如果条件1满足,则[平民]胜利;如果条件2满足,则[卧底]胜利。\n' + '\n' + '##输出##\n' + 'Python可直接解析的jsonstr,格式如下:\n' + '{{\"原因是\": 获胜者为[平民]或[卧底]的原因, \"角色分配结果为\": 所有玩家的身份和单词(根据本局游戏历史记录), \"获胜方为\": \"平民\" or \"卧底\"}}\n' + '以{{开头,任何其他内容都是不允许的!\n' + '\n' + '##输出示例##\n' + '{{\"原因是\": \"卧底数量为0\", \"角色分配结果为\": \"李静:身份为卧底,单词为香蕉;人类玩家:身份为平民, 单词为梨子; 张伟:身份为平民, 单词为梨子; 王鹏:身份为平民, 单词为梨子。\", \"获胜方为\": \"平民\"}}\n' + '\n' + '##注意##\n' + '请输出所有玩家的角色分配结果,不要遗漏信息\n' + '\n' + '##结果##\n\n' + ) + + model_config = None + try: + model_config = ModelConfig( + config_name="codefuse_default", + model_type=os.environ.get("DEFAULT_MODEL_TYPE"), + model_name=os.environ.get("DEFAULT_MODEL_NAME"), + api_key=os.environ.get("DEFAULT_API_KEY"), + api_url=os.environ.get("DEFAULT_API_URL"), + ) + memory = kwargs.get("memory") or "" + model = get_model(model_config) + content = model.predict(template.format(memory=memory)) + except Exception as e: + content = f"无法正确调用模型: {e}, {model_config}" + return content \ No newline at end of file diff --git a/muagent/tools/werewolf.py b/muagent/tools/werewolf.py new file mode 100644 index 0000000..6457f37 --- /dev/null +++ b/muagent/tools/werewolf.py @@ -0,0 +1,314 @@ +import os +from typing import ( + List, + Dict +) +from loguru import logger +from pydantic import BaseModel, Field +import random + + +from .base_tool import BaseToolModel +from ..models import get_model, ModelConfig + + + +class RoleAssignmentTool(BaseToolModel): + name = "狼人杀-角色分配工具" + description = "狼人杀的角色分配工具,可以为每一位玩家分配一个单词和人物角色。" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + roles: list + + @classmethod + def run(cls, **kwargs) -> ToolOutputArgs: + """Execute your tool!""" + + players = [ + ["朱丽", "agent_朱丽"], + ["周杰", "agent_周杰"], + ["沈强", "agent_沈强"], + ["韩刚", "agent_韩刚"], + ["梁军", "agent_梁军"], + ["周欣怡", "agent_周欣怡"], + ["贺子轩", "agent_贺子轩"], + ["人类玩家", "agent_人类玩家"] + ] + random.shuffle(players) + roles = ["平民_1", "平民_2", "平民_3", "狼人_1", "狼人_2", "狼人_3", "女巫", "预言家"] + random.shuffle(roles) + + assigned_roles = [] + for i in range(len(players)): + assigned_roles.append({ + "player_name": players[i][0], + "agent_name": players[i][1], + "agent_description": roles[i] + }) + return assigned_roles + + + +class PlayerSeatingTool(BaseToolModel): + name: str = "狼人杀-座位分配" + description: str = "狼人杀的座位分配工具,可以将玩家顺序打乱随机分配座位" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + seating_chart: str + + @classmethod + def 
run(cls, **kwargs) -> ToolOutputArgs: + """Execute your tool!""" + + players = [ + ["朱丽", "agent_朱丽"], + ["周杰", "agent_周杰"], + ["沈强", "agent_沈强"], + ["韩刚", "agent_韩刚"], + ["梁军", "agent_梁军"], + ["周欣怡", "agent_周欣怡"], + ["贺子轩", "agent_贺子轩"], + ["人类玩家", "agent_人类玩家"] + ] + n = len(players) + random.shuffle(players) + + seating_chart = "\n\n| 座位 | 玩家 |\n|---|---|\n" + seating_chart += "\n".join(f"| {i} | **{players[i-1][0]}** |" for i in range(1, n + 1)) + + return seating_chart + + +class WerewolfGameInstructionTool(BaseToolModel): + name = "狼人杀-游戏指令" + description = "狼人杀的游戏指令工具,需要根据记忆的上下文信息,当前任务、以及你拿到的身份信息来进行回应。" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + instruction: str + + @classmethod + def run(cls, **kwargs) -> ToolOutputArgs: + """Execute your tool!""" + template = ( + '##狼人杀游戏说明##\n' + '这个游戏基于文字交流, 以下是游戏规则:\n' + '角色:\n' + '主持人也是游戏的组织者,你需要正确回答他的指示。游戏中有五种角色:狼人、平民、预言家、女巫和猎人,三个狼人,一个预言家,一个女巫,一个猎人,两个平民。\n' + '好人阵营: 村民、预言家、猎人和女巫。\n' + '游戏阶段:游戏分为两个交替的阶段:白天和黑夜。\n' + '黑夜:\n' + '在黑夜阶段,你与主持人的交流内容是保密的。你无需担心其他玩家和主持人知道你说了什么或做了什么。\n' + '- 如果你是狼人,你需要和队友一起选择袭击杀死一个玩家\n' + '- 如果你是女巫,你有一瓶解药,可以拯救被狼人袭击的玩家,以及一瓶毒药,可以在黑夜后毒死一个玩家。解药和毒药只能使用一次。\n' + '- 如果你是预言家,你可以在每个晚上检查一个玩家是否是狼人,这非常重要。\n' + '- 如果你是猎人,当你在黑夜被狼人杀死时可以选择开枪杀死任意一名玩家。\n' + '- 如果你是村民,你在夜晚无法做任何事情。\n' + '白天:\n' + '你与存活所有玩家(包括敌人)讨论。讨论结束后,玩家投票来淘汰一个自己怀疑是狼人的玩家。获得最多票数的玩家将被淘汰。主持人将告诉谁被杀,否则将没有人被杀。\n' + '如果你是猎人,当你在白天被投票杀死之后可以选择开枪杀死任意一名玩家。\n' + '游戏目标:\n' + '狼人的目标是杀死所有的好人阵营中的玩家,并且不被好人阵营的玩家识别出狼人身份;\n' + '好人阵营的玩家,需要找出并杀死所有的狼人玩家。\n' + '##注意##\n' + '你正在参与狼人杀这个游戏,你应该感知自己的名字、座位号和角色。\n' + '1. 若你的角色为狼人,白天的发言应该尽可能隐藏身份。\n' + '2. 若你的角色属于好人阵营,白天的发言应该根据游戏进展尽可能分析出谁是狼人。\n' + '##以下为目前游戏进展##\n' + '{memory}\n' + '##发言格式##\n' + '你的回答中需要包含你的想法并给出简洁的理由,注意请有理有据,白天的发言尽量不要与别人的发言内容重复。发言的格式应该为Python可直接解析的jsonstr,格式如下:\n' + '{{\"thought\": 以“我是【座位号】号玩家【名字】【角色】”开头,根据主持人的通知感知自己的【名字】、【座位号】、【角色】,根据游戏进展和自己游戏角色的当前任务分析如何发言,字数不超过150字, \"output\": 您的发言应该符合目前游戏进展和自己角色的逻辑,白天投票环节不能投票给自己。}}\n' + '##开始发言##\n' + ) + model_config = None + try: + model_config = ModelConfig( + config_name="codefuse_default", + model_type=os.environ.get("DEFAULT_MODEL_TYPE"), + model_name=os.environ.get("DEFAULT_MODEL_NAME"), + api_key=os.environ.get("DEFAULT_API_KEY"), + api_url=os.environ.get("DEFAULT_API_URL"), + ) + memory = kwargs.get("memory") or "" + model = get_model(model_config) + content = model.predict(template.format(memory=memory)) + except Exception as e: + content = f"无法正确调用模型: {e}, {model_config}" + return content + + +class AgentZhuliTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_朱丽" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是朱丽" + ) + + +class AgentZhoujieTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_周杰" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是周杰" + ) + + +class AgentShenqiangTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_沈强" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是沈强" + ) + +class AgentHangangTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_韩刚" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是韩刚" + ) + + +class AgentLiangjunTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_梁军" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是梁军" + ) + +class AgentZhouxinyiTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_周欣怡" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是周欣怡" + ) + + +class AgentHezixuanTool(WerewolfGameInstructionTool): + name = "狼人杀-agent_贺子轩" + description = ( + f"你是一个智能体(Agent),你正在模拟玩家参与狼人杀这场游戏,在游戏中你的名字是贺子轩" + ) + + 
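+
+# Usage sketch (hypothetical snippet, not part of the game flow): thanks to the
+# _ToolWrapperMeta registry in base_tool.py, each class above is registered under
+# its `name` attribute and can be resolved and invoked by display name:
+#
+#     from muagent.tools import get_tool
+#     tool = get_tool("狼人杀-agent_朱丽")      # resolved via _toolname_registry
+#     reply = tool.run(memory="主持人:天黑请闭眼……")  # memory: the game log so far
+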
+class WerewolfGameEndCheckerTool(BaseToolModel): + name = "狼人杀-胜利条件判断" + description = "狼人杀的胜利条件判断工具,判断当前狼人杀游戏是否结束。" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + thought: str + players: dict + isEnd: str + + @classmethod + def run(cls, **kwargs) -> ToolOutputArgs: + """Execute your tool!""" + template = ( + '##本局游戏历史记录##\n' + '{memory}\n' + '\n' + '##背景##\n' + '你是一个逻辑判断大师,你正在参与“狼人杀”这个游戏,你的角色是[主持人]。你熟悉“狼人杀”游戏的完整流程,现在需要判断当前游戏是否结束。\n' + '\n' + '##任务##\n' + '你的任务是判断当前游戏是否结束,规则如下:\n' + '根据【重要信息】感知每一轮被投票死亡、被狼人杀死、被女巫毒死、被猎人带走的玩家。 统计目前存活的[好人]玩家数量、[狼人]玩家数量。格式{{\"存活的好人\":[player_name], \"存活的狼人\":[player_name]}},判断以下条件中的一个是否满足:\n' + '1. 存活的“狼人”玩家数量为0。\n' + '2. “狼人”数量超过了“好人”数量。\n' + '3. “狼人”数量等于“好人”数量,“女巫”已死亡或者她的毒药已经使用。\n' + '若某个条件满足,游戏结束;否则游戏没有结束。\n' + '\n' + '##输出##\n' + '返回JSON格式,格式为:{{\"thought\": str, \"存活的玩家信息\": {{\"存活的好人\":[player_name], \"存活的狼人\":[player_name]}}, \"isEnd\": \"是\" or \"否\"}}\n' + '-thought **根据本局游戏历史记录** 分析 游戏最开始有哪些玩家, 他们的身份是什么, 投票导致死亡的玩家有哪些? 被狼人杀死的玩家有哪些? 被女巫毒死的玩家是谁? 被猎人带走的玩家是谁?分析当前存活的玩家有哪些? 是否触发了游戏结束条件? 等等。\n' + '\n' + '##example##\n' + '{{\"thought\": \"**游戏开始时** 有 小杭、小北、小赵、小钱、小孙、小李、小夏、小张 八位玩家, 其中 小杭、小北、小赵是[狼人], 小钱、小孙是[平民], 小李是[预言家],小夏是[女巫],小张是[猎人],小张在第一轮被狼人杀死了,猎人没有开枪,[狼人]数量大于[好人]数量,因此游戏未结束。\", \"存活的玩家信息\": {{\"存活的狼人\":[\"小杭\", \"小北\", \"小赵\"]}}, {{\"存活的好人\":[\"小钱\", \"小孙\", \"小李\", \"小夏\"]}}, \"isEnd\": \"否\" }}\n' + '##注意事项##\n' + '1. 所有玩家的座位、身份、agent_name、存活状态、游戏进展等信息在开头部分已给出。\n' + '2. \"是\" or \"否\" 如何选择?若游戏结束,则为\"是\",否则为\"否\"。\n' + '3. 请直接输出jsonstr,不用输出markdown格式。\n' + '4. 游戏可能进行了不只一轮,可能有1个或者2个玩家已经死亡,请注意感知。\n' + ) + model_config = None + try: + model_config = ModelConfig( + config_name="codefuse_default", + model_type=os.environ.get("DEFAULT_MODEL_TYPE"), + model_name=os.environ.get("DEFAULT_MODEL_NAME"), + api_key=os.environ.get("DEFAULT_API_KEY"), + api_url=os.environ.get("DEFAULT_API_URL"), + ) + memory = kwargs.get("memory") or "" + model = get_model(model_config) + content = model.predict(template.format(memory=memory)) + except Exception as e: + content = f"无法正确调用模型: {e}, {model_config}" + return content + + + +class WerewolfGameOutcomeTool(BaseToolModel): + name = "狼人杀-结果输出" + description = "狼人杀的结果输出工具,判断狼人杀游戏中最终的胜利方是谁。" + + class ToolInputArgs(BaseModel): + pass + + class ToolOutputArgs(BaseModel): + reason: str + 角色分配结果为: str + 获胜方为: str + + @classmethod + def run(cls, **kwargs) -> ToolOutputArgs: + """Execute your tool!""" + template = ( + '##本局游戏历史记录##\n' + '{memory}\n' + '\n' + '##背景##\n' + '您正在参与“狼人杀”这个游戏,角色是[主持人]。现在游戏已经结束,您需要判断胜利的一方是谁。\n' + '\n' + '##任务##\n' + '统计目前存活的[好人]玩家数量、[狼人]玩家数量。判断以下条件中的哪一个满足:\n' + '1. 存活的“狼人”玩家数量为0。\n' + '2. “狼人”数量超过了“好人”数量。\n' + '3. 
“狼人”数量等于“好人”数量,“女巫”已死亡或者她的毒药已经使用。\n' + '如果条件1满足,则[好人]胜利;如果条件2或者条件3满足,则[狼人]胜利。\n' + '\n' + '##输出##\n' + 'Python可直接解析的jsonstr,格式如下:\n' + '{{\"原因是\": 获胜者为[好人]或[狼人]的原因, \"角色分配结果为\": 所有玩家的角色(根据本局游戏历史记录), \"获胜方为\": \"好人\" or \"狼人\"}}\n' + '以{{开头,任何其他内容都是不允许的!\n' + '\n' + '##输出示例##\n' + '{{\"原因是\": \"狼人数量为0\", \"角色分配结果为\": \"沈强:身份为狼人_1;周欣怡:身份为狼人_2;梁军:身份为狼人_3;贺子轩:身份为平民_1;人类玩家:身份为平民_2;朱丽:身份为预言家;韩刚:身份为女巫;周杰:身份为猎人。\", \"获胜方为\": \"好人\"}}\n' + '\n' + '##注意##\n' + '请输出所有玩家的角色分配结果,不要遗漏信息。\n' + '\n' + '##结果##\n' + '\n' + ) + + model_config = None + try: + model_config = ModelConfig( + config_name="codefuse_default", + model_type=os.environ.get("DEFAULT_MODEL_TYPE"), + model_name=os.environ.get("DEFAULT_MODEL_NAME"), + api_key=os.environ.get("DEFAULT_API_KEY"), + api_url=os.environ.get("DEFAULT_API_URL"), + ) + memory = kwargs.get("memory") or "" + model = get_model(model_config) + content = model.predict(template.format(memory=memory)) + except Exception as e: + content = f"无法正确调用模型: {e}, {model_config}" + return content \ No newline at end of file diff --git a/muagent/utils/common_utils.py b/muagent/utils/common_utils.py index 977420b..5939390 100644 --- a/muagent/utils/common_utils.py +++ b/muagent/utils/common_utils.py @@ -31,6 +31,13 @@ def addMinutesToTime(input_time: str, n: int = 5, dateformat=DATE_FORMAT): new_time_after = dt + timedelta(minutes=n) return new_time_before.strftime(dateformat), new_time_after.strftime(dateformat) +def addMinutesToTimestamp(input_time: str, n: int = 5, dateformat=DATE_FORMAT): + dt = datetime.strptime(input_time, dateformat) + + # 前后加N分钟 + new_time_before = dt - timedelta(minutes=n) + new_time_after = dt + timedelta(minutes=n) + return new_time_before.timestamp(), new_time_after.timestamp() def timestampToDateformat(ts, interval=1000, dateformat=DATE_FORMAT): '''将标准时间戳转换标准指定时间格式''' @@ -131,6 +138,44 @@ def double_hashing(s: str, modulus: int = 10e12) -> int: return int((hash1 + hash2) % modulus) +def _convert_to_str(content: Any) -> str: + """Convert the content to string. + + The implementation of this _convert_to_str are borrowed from + https://github.com/modelscope/agentscope/blob/main/src/agentscope/utils/common.py + + Note: + For prompt engineering, simply calling `str(content)` or + `json.dumps(content)` is not enough. + + - For `str(content)`, if `content` is a dictionary, it will turn double + quotes to single quotes. When this string is fed into prompt, the LLMs + may learn to use single quotes instead of double quotes (which + cannot be loaded by `json.loads` API). + + - For `json.dumps(content)`, if `content` is a string, it will add + double quotes to the string. LLMs may learn to use double quotes to + wrap strings, which leads to the same issue as `str(content)`. + + To avoid these issues, we use this function to safely convert the + content to a string used in prompt. + + Args: + content (`Any`): + The content to be converted. + + Returns: + `str`: The converted string. 
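+
+    Example (illustrative):
+        >>> _convert_to_str("hello")
+        'hello'
+        >>> _convert_to_str({"a": 1})
+        '{"a": 1}'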
+ """ + + if isinstance(content, str): + return content + elif isinstance(content, (dict, list, int, float, bool, tuple)): + return json.dumps(content, ensure_ascii=False) + else: + return str(content) + + @contextlib.contextmanager def timer(seconds: Optional[Union[int, float]] = None) -> Generator: """ diff --git a/requirements.txt b/requirements.txt index 115cf6a..47a6dfb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ Pyarrow python-magic-bin; sys_platform == 'win32' SQLAlchemy==2.0.19 docker -Levenshtein +edit_distance redis==5.0.1 pydantic<=1.10.14 aliyun-log-python-sdk==0.9.0 diff --git a/setup.py b/setup.py index 1553d9c..7fb0ef2 100644 --- a/setup.py +++ b/setup.py @@ -35,8 +35,12 @@ "notebook", "docker", "sseclient", - "Levenshtein", + "edit_distance", "urllib3==1.26.6", + "ollama", + "colorama", + "pycryptodome", + "dashscope" # "chromadb==0.4.17", "javalang==0.13.0", diff --git a/tests/agents/funccall_agent_test.py b/tests/agents/funccall_agent_test.py new file mode 100644 index 0000000..256e173 --- /dev/null +++ b/tests/agents/funccall_agent_test.py @@ -0,0 +1,97 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + + +from muagent.schemas import Message, Memory +from muagent.agents import FunctioncallAgent +from muagent import get_agent, get_project_config_from_env + + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "0" + +AGENT_CONFIGS = { + "codefuse_function_caller": { + "config_name": "codefuse_function_caller", + "agent_type": "FunctioncallAgent", + "agent_name": "codefuse_function_caller", + "llm_config_name": "qwener" + } +} +os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) +project_config = get_project_config_from_env() +tools = ["KSigmaDetector", "MetricsQuery"] +tools = [ + "谁是卧底-座位分配", "谁是卧底-角色分配", "谁是卧底-结果输出", "谁是卧底-胜利条件判断", + "谁是卧底-张伟", "谁是卧底-李静", "谁是卧底-王鹏", +] + +# tools = [ +# "狼人杀-角色分配工具", "狼人杀-座位分配", "狼人杀-胜利条件判断", "狼人杀-结果输出", +# '狼人杀-agent_朱丽', '狼人杀-agent_周杰', '狼人杀-agent_沈强', '狼人杀-agent_韩刚', +# '狼人杀-agent_梁军', '狼人杀-agent_周欣怡', '狼人杀-agent_贺子轩' +# ] + +agent = FunctioncallAgent( + agent_name="codefuse_function_caller", + project_config=project_config, + tools=tools +) + + +memory_content = "[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.89, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.89, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.89, 26.789, 28.901]" +memory = Memory( + messages=[Message( + role_type="observation", + content=memory_content + )] +) +query_content = "帮我查询下127.0.0.1这个服务器的在10点的数据" +query_content = "帮我判断这个数据是否异常" +query_content = "开始分配座位" +query_content = "开始分配身份" +query_content = "游戏是否结束" +query_content = "游戏的胜利玩家是谁" + +memory_content = "3号玩家说今天天气很好" +memory = Memory( + messages=[Message( + role_type="observation", + content=memory_content + )] +) +query_content = "我要使用工具,工具描述为agent_张伟" +# query_content = "我要使用工具,工具描述为'agent_周杰'" + + +query = Message( + role_name="human", + role_type="user", + content=query_content, +) +# agent.pre_print(query) +# output_message = agent.step(query, memory=memory) +output_message = 
agent.step(query, extra_params={"memory": memory_content}) +print("### intput ###\n", output_message.input_text) +print("### content ###\n", output_message.content) +print("### observation ###\n", output_message.parsed_contents[-1]["Observation"]) +print("### step content ###\n", output_message.step_content) \ No newline at end of file diff --git a/tests/agents/group_agent_test.py b/tests/agents/group_agent_test.py new file mode 100644 index 0000000..182aa9c --- /dev/null +++ b/tests/agents/group_agent_test.py @@ -0,0 +1,71 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent.tools import TOOL_SETS +from muagent.schemas import Message +from muagent.agents import BaseAgent +from muagent.project_manager import get_project_config_from_env + + +tools = list(TOOL_SETS) +tools = ["KSigmaDetector", "MetricsQuery"] +role_prompt = "you are a helpful assistant!" + +AGENT_CONFIGS = { + "grouper": { + "agent_type": "GroupAgent", + "agent_name": "grouper", + "agents": ["codefuse_reacter_1", "codefuse_reacter_2"] + }, + "codefuse_reacter_1": { + "agent_type": "ReactAgent", + "agent_name": "codefuse_reacter_1", + "tools": tools, + }, + "codefuse_reacter_2": { + "agent_type": "ReactAgent", + "agent_name": "codefuse_reacter_2", + "tools": tools, + } +} +os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "0" + +# +project_config = get_project_config_from_env() +agent = BaseAgent.init_from_project_config( + "grouper", project_config +) + +query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下" +query = Message( + role_name="human", + role_type="user", + content=query_content, +) +# agent.pre_print(query) +output_message = agent.step(query) +print("input:", output_message.input_text) +print("content:", output_message.content) +print("step_content:", output_message.step_content) \ No newline at end of file diff --git a/tests/agents/react_agent_test.py b/tests/agents/react_agent_test.py new file mode 100644 index 0000000..3703812 --- /dev/null +++ b/tests/agents/react_agent_test.py @@ -0,0 +1,62 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent.tools import TOOL_SETS +from muagent.schemas import Message +from muagent.agents import BaseAgent +from muagent import get_project_config_from_env + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "0" + +tools = list(TOOL_SETS) +tools = ["KSigmaDetector", "MetricsQuery"] +role_prompt = "you are a helpful assistant!" 
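+
+# Assumed wiring: "llm_config_name" must match a key in the MODEL_CONFIGS JSON
+# that test_config puts into the environment (e.g. "qwen_chat");
+# get_project_config_from_env() reads both variables and stitches the agent and
+# model configs together.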
+ +AGENT_CONFIGS = { + "reacter": { + "system_prompt": role_prompt, + "agent_type": "ReactAgent", + "agent_name": "reacter", + "tools": tools, + "llm_config_name": "qwen_chat" + } +} +os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + +# +project_config = get_project_config_from_env() +agent = BaseAgent.init_from_project_config( + "reacter", project_config +) + +query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下" +query = Message( + role_name="human", + role_type="user", + content=query_content, +) +# agent.pre_print(query) +output_message = agent.step(query) +print("### intput ###\n", output_message.input_text) +print("### content ###\n", output_message.content) +print("### step content ###\n", output_message.step_content) \ No newline at end of file diff --git a/tests/agents/single_agent_test.py b/tests/agents/single_agent_test.py new file mode 100644 index 0000000..ededa33 --- /dev/null +++ b/tests/agents/single_agent_test.py @@ -0,0 +1,119 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS + +from muagent.schemas import Message +from muagent.models import ModelConfig +from muagent.agents import SingleAgent, BaseAgent +from muagent import get_project_config_from_env + + + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "0" + +role_prompt = "you are a helpful assistant!" +role_prompt = """#### AGENT PROFILE +you are a helpful assistant! + +#### RESPONSE OUTPUT FORMAT +**Action Status:** Set to 'stopped' or 'code_executing'. +If it's 'stopped', the action is to provide the final answer to the session records and executed steps. +If it's 'code_executing', the action is to write the code. + +**Action:** +```python +# Write your code here +... +``` +""" + +role_prompt = """#### AGENT PROFILE +you are a helpful assistant! + +#### RESPONSE OUTPUT FORMAT +**Action Status:** Set to either 'stopped' or 'tool_using'. If 'stopped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool. + +**Action:** Use the tools by formatting the tool action in JSON. 
The format should be: + +```json +{ + "tool_name": "$TOOL_NAME", + "tool_params": "$INPUT" +} +``` +""" + +tools = list(TOOL_SETS) +tools = ["KSigmaDetector", "MetricsQuery"] + + +AGENT_CONFIGS = { + "codefuse_simpler": { + "agent_type": "SingleAgent", + "agent_name": "codefuse_simpler", + "tools": tools, + "llm_config_name": "qwen_chat" + } +} +os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + + +project_config = get_project_config_from_env() +agent = BaseAgent.init_from_project_config( + "codefuse_simpler", project_config +) +# base_agent = SingleAgent( +# system_prompt=role_prompt, +# project_config=project_config, +# tools=tools +# ) + + +question = "用python画一个爱心" +query = Message( + session_index="agent_test", + role_type="user", + role_name="user", + content=question, +) + +# base_agent.pre_print(query) +# output_message = base_agent.step(query) +# print(output_message.input_text) +# print(output_message.content) + + + + +query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下" +query = Message( + role_name="human", + role_type="user", + input_text=query_content, +) +# base_agent.pre_print(query) +output_message = agent.step(query) +print("### intput ###\n", output_message.input_text) +print("### content ###\n", output_message.content) +print("### step content ###\n", output_message.step_content) \ No newline at end of file diff --git a/tests/agents/task_agent_test.py b/tests/agents/task_agent_test.py new file mode 100644 index 0000000..c446ed2 --- /dev/null +++ b/tests/agents/task_agent_test.py @@ -0,0 +1,66 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS + +from muagent.schemas import Message +from muagent.models import ModelConfig +from muagent.agents import BaseAgent +from muagent import get_project_config_from_env + + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "0" + + +tools = list(TOOL_SETS) +tools = ["KSigmaDetector", "MetricsQuery"] +role_prompt = "you are a helpful assistant!" 
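+
+# The query below is a two-step task; the expectation (an assumption based on
+# the tool names) is that TaskAgent first calls MetricsQuery to fetch the series
+# and then KSigmaDetector to judge whether it is anomalous.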
+ +AGENT_CONFIGS = { + "tasker": { + "system_prompt": role_prompt, + "agent_type": "TaskAgent", + "agent_name": "tasker", + "tools": tools, + "llm_config_name": "qwen_chat" + } +} +os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + +# +project_config = get_project_config_from_env() +agent = BaseAgent.init_from_project_config( + "tasker", project_config +) + +query_content = "先帮我获取下127.0.0.1这个服务器在10点的数,然后在帮我判断下数据是否存在异常" +query = Message( + role_name="human", + role_type="user", + content=query_content, +) +# agent.pre_print(query) +output_message = agent.step(query) +print("### intput ###\n", output_message.input_text) +print("### content ###\n", output_message.content) +print("### step content ###\n", output_message.step_content) \ No newline at end of file diff --git a/tests/llm_models/embedding_test.py b/tests/llm_models/embedding_test.py new file mode 100644 index 0000000..3e3a284 --- /dev/null +++ b/tests/llm_models/embedding_test.py @@ -0,0 +1,48 @@ +from loguru import logger +import os, sys +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) +try: + import test_config + api_key = os.environ["OPENAI_API_KEY"] + api_base_url= os.environ["API_BASE_URL"] + model_name = os.environ["model_name"] + embed_model = os.environ["embed_model"] + embed_model_path = os.environ["embed_model_path"] +except Exception as e: + # set your config + api_key = "" + api_base_url= "" + model_name = "" + embed_model = "" + embed_model_path = "" + logger.error(f"{e}") + + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent.models import get_model +from muagent.schemas.models import ModelConfig +import json + +model_configs = json.loads(os.environ["MODEL_CONFIGS"]) + +for model_type in model_configs.keys(): + if "_embedding" not in model_type: continue + model_config = model_configs[model_type] + embed_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=model_config["model_name"], + api_key=model_config["api_key"], + ) + + model = get_model(embed_config) + + + print(model_type, model_config["model_name"], len(model.embed_query("hello"))) \ No newline at end of file diff --git a/tests/llm_models/model_test.py b/tests/llm_models/model_test.py new file mode 100644 index 0000000..f2e4471 --- /dev/null +++ b/tests/llm_models/model_test.py @@ -0,0 +1,93 @@ +from loguru import logger +import os, sys +import json + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) +try: + import test_config + api_key = os.environ["OPENAI_API_KEY"] + api_base_url= os.environ["API_BASE_URL"] + model_name = os.environ["model_name"] + embed_model = os.environ["embed_model"] + embed_model_path = os.environ["embed_model_path"] +except Exception as e: + # set your config + api_key = "" + api_base_url= "" + model_name = "" + embed_model = "" + embed_model_path = "" + logger.error(f"{e}") + + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent.models import get_model +from muagent.schemas.models import ModelConfig + +model_configs = json.loads(os.environ["MODEL_CONFIGS"]) + +# "openai_chat","yi_chat","qwen_chat", "dashscope_chat""moonshot_chat", "ollama_chat" + +model_type = "ollama_chat" +model_config = model_configs[model_type] + +model_config = ModelConfig( + 
config_name="model_test", + model_type=model_type, + model_name=model_config["model_name"], + api_key=model_config["api_key"], +) +model = get_model(model_config) + +# 工具 +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location", "unit"], + "additionalProperties": False, + }, + }, + } +] + + +# print(model.generate("输出 '今天你好'", stop="你", format_type='str')) +for i in model.generate_stream("hello", stop="你", format_type='str'): + print(i) + +# # +# print(model.generate("hello", format_type='str')) + +# # +# for i in model.generate_stream("hello", format_type='str'): +# print(i) + +# # +# print(model.chat([{"role": "user", "content":"hello"}], format_type='str')) + +# # +# for i in model.chat_stream([{"role": "user", "content":"hello"}], format_type='str'): +# print(i) + +# # +# print(model.function_call(tools=tools, prompt="我想查北京的天气")) + +# # +# for i in model.function_call_stream(tools=tools, messages=[{"role": "user", "content":"我想查北京的天气"}]): +# print(i) \ No newline at end of file diff --git a/tests/memory_manager/local_memory_manager_test.py b/tests/memory_manager/local_memory_manager_test.py new file mode 100644 index 0000000..2c5ac16 --- /dev/null +++ b/tests/memory_manager/local_memory_manager_test.py @@ -0,0 +1,161 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config + api_key = os.environ["OPENAI_API_KEY"] + api_base_url= os.environ["API_BASE_URL"] + model_name = os.environ["model_name"] + model_engine = os.environ["model_engine"] + embed_model = os.environ["embed_model"] + embed_model_path = os.environ["embed_model_path"] +except Exception as e: + # set your config + api_key = "" + api_base_url= "" + model_name = "" + model_engine = os.environ["model_engine"] + embed_model = "" + embed_model_path = "" + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + + +from muagent.utils.common_utils import getCurrentDatetime +from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig +from muagent.schemas import Message +from muagent.models import ModelConfig, get_model + +from muagent.memory_manager import LocalMemoryManager + + +from muagent.llm_models.llm_config import EmbedConfig, LLMConfig + +model_configs = json.loads(os.environ["MODEL_CONFIGS"]) + +# +# llm_config = LLMConfig( +# model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3, +# ) +model_type = "qwen_chat" +model_config = model_configs[model_type] +model_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=model_config["model_name"], + api_key=model_config["api_key"], +) + + +# embed_config = EmbedConfig( +# embed_engine="model", embed_model=embed_model, embed_model_path=embed_model_path +# ) +model_type = "qwen_text_embedding" +embed_config = model_configs[model_type] +embed_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=embed_config["model_name"], + api_key=embed_config["api_key"], +) + +# prepare your message +message1 = Message( + session_index="default", + 
role_name="test1", + role_type="user", + content="hello", + spec_parsed_contents=[{"input": "hello"}], +) + +text = "hi! how can I help you?" +message2 = Message( + session_index="shuimo", + role_name="test2", + role_type="assistant", + content=text, + arsed_output_list=[{"answer": text}], +) + +text = "they say hello and hi to each other" +message3 = Message( + session_index="shanshi", + role_name="test3", + role_type="summary", + content=text, + spec_parsed_contents=[{"summary": text}], +) + +vb_config = VBConfig(vb_type="LocalFaissHandler") + +# append or extend test +print("###"*10 + "append or extend" + "###"*10) +local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=True) +# append can ignore user_name +local_memory_manager.append(message=message1) +local_memory_manager.append(message=message2) +local_memory_manager.append(message=message3) + +# test init_local +print("###"*10 + "dont load local" + "###"*10) +local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=True) +print(local_memory_manager.get_memory_pool("default").to_format_messages( + content_key="content", format_type='str')) +print(local_memory_manager.get_memory_pool("shuimo").to_format_messages( + content_key="content", format_type='str')) +print(local_memory_manager.get_memory_pool("shanshi").to_format_messages( + content_key="content", format_type='str')) + +# test load from local +print("###"*10 + "load local" + "###"*10) +local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=False) +print(local_memory_manager.get_memory_pool("default").to_format_messages( + content_key="content", format_type='str')) +print(local_memory_manager.get_memory_pool("shuimo").to_format_messages( + content_key="content", format_type='str')) +print(local_memory_manager.get_memory_pool("shanshi").to_format_messages( + content_key="content", format_type='str')) + + + +local_memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=False) +# embedding retrieval test +print("###"*10 + "retrieval" + "###"*10) +text = "say hi to each other," +# retrieval_type=datetime => retrieval from datetime and jieba +print(local_memory_manager.router_retrieval( + session_index="shanshi", text=text, datetime=getCurrentDatetime(), + n=4, top_k=5, retrieval_type= "datetime")) +# retrieval_type=embedding => retrieval from embedding +print(local_memory_manager.router_retrieval( + session_index="shanshi", text=text, top_k=5, retrieval_type= "embedding")) +# retrieval_type=text => retrieval from jieba +print(local_memory_manager.router_retrieval( + session_index="shanshi", text=text, top_k=5, retrieval_type= "text")) + +# # recursive_summary test +print("###"*10 + "recursive_summary" + "###"*10) +print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shanshi").messages, split_n=1, session_index="shanshi")) + +# print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shuimo").messages, split_n=1, session_index="shanshi")) + +# print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("default").messages, split_n=1, session_index="shanshi")) + + +# test after clear local vs and jsonl +print("###"*10 + "test after clear local vs and jsonl" + "###"*10) +local_memory_manager.clear_local(re_init=True) 
+print(local_memory_manager.get_memory_pool("shanshi").to_format_messages( + content_key="content", format_type='str')) \ No newline at end of file diff --git a/tests/memory_manager/local_mm_crud_test.py b/tests/memory_manager/local_mm_crud_test.py new file mode 100644 index 0000000..d5211cb --- /dev/null +++ b/tests/memory_manager/local_mm_crud_test.py @@ -0,0 +1,117 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config + api_key = os.environ["OPENAI_API_KEY"] + api_base_url= os.environ["API_BASE_URL"] + model_name = os.environ["model_name"] + model_engine = os.environ["model_engine"] + embed_model = os.environ["embed_model"] + embed_model_path = os.environ["embed_model_path"] +except Exception as e: + # set your config + api_key = "" + api_base_url= "" + model_name = "" + model_engine = os.environ["model_engine"] + embed_model = "" + embed_model_path = "" + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + + +from muagent.utils.common_utils import getCurrentDatetime +from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig +from muagent.schemas import Message +from muagent.models import ModelConfig, get_model + +from muagent.memory_manager import LocalMemoryManager, TbaseMemoryManager + + +from muagent.llm_models.llm_config import EmbedConfig, LLMConfig + +model_configs = json.loads(os.environ["MODEL_CONFIGS"]) + +# +# llm_config = LLMConfig( +# model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3, +# ) +model_type = "qwen_chat" +model_config = model_configs[model_type] +model_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=model_config["model_name"], + api_key=model_config["api_key"], +) + + +# embed_config = EmbedConfig( +# embed_engine="model", embed_model=embed_model, embed_model_path=embed_model_path +# ) +model_type = "qwen_text_embedding" +embed_config = model_configs[model_type] +embed_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=embed_config["model_name"], + api_key=embed_config["api_key"], +) + + +# 初始化 TbaseHandler 实例 +tb_config = TBConfig( + tb_type="TbaseHandler", + index_name="muagent_test", + host="127.0.0.1", + port=os.environ['tb_port'], + username=os.environ['tb_username'], + password=os.environ['tb_password'], +) + +vb_config = VBConfig(vb_type="LocalFaissHandler") + +# append or extend test +# memory_manager = LocalMemoryManager(embed_config=embed_config, llm_config=model_config, vb_config=vb_config, do_init=True) +memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config) + + +# prepare your message +message1 = Message( + session_index="default", + message_index="default", + role_name="crud_test", + role_type="user", + content="hello", + role_tags=["shanshi"] +) + +# append can ignore user_name +memory_manager.append(message=message1) +print(memory_manager.get_memory_pool("default").to_format_messages(format_type="raw")) + +# prepare your message +message2 = Message( + session_index="default", + message_index="default", + role_name="crud_test", + role_type="user", + content="hello", + role_tags=["test"] +) + +memory_manager.append(message=message2, 
role_tag="test") +print(memory_manager.get_memory_pool("default").to_format_messages(format_type="raw")) \ No newline at end of file diff --git a/tests/memory_manager/tbase_memory_manager_test.py b/tests/memory_manager/tbase_memory_manager_test.py new file mode 100644 index 0000000..f899cac --- /dev/null +++ b/tests/memory_manager/tbase_memory_manager_test.py @@ -0,0 +1,153 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config + api_key = os.environ["OPENAI_API_KEY"] + api_base_url= os.environ["API_BASE_URL"] + model_name = os.environ["model_name"] + model_engine = os.environ["model_engine"] + embed_model = os.environ["embed_model"] + embed_model_path = os.environ["embed_model_path"] +except Exception as e: + # set your config + api_key = "" + api_base_url= "" + model_name = "" + model_engine = os.environ["model_engine"] + embed_model = "" + embed_model_path = "" + logger.error(f"{e}") + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + + +from muagent.llm_models.llm_config import EmbedConfig, LLMConfig +from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig +from muagent.schemas import Message +from muagent.models import ModelConfig, get_model + +from muagent.memory_manager import TbaseMemoryManager +from muagent.utils.common_utils import getCurrentDatetime + +model_configs = json.loads(os.environ["MODEL_CONFIGS"]) + +# +# llm_config = LLMConfig( +# model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3, +# ) +model_type = "qwen_chat" +model_config = model_configs[model_type] +model_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=model_config["model_name"], + api_key=model_config["api_key"], +) + + +# embed_config = EmbedConfig( +# embed_engine="model", embed_model=embed_model, embed_model_path=embed_model_path +# ) +model_type = "qwen_text_embedding" +embed_config = model_configs[model_type] +embed_config = ModelConfig( + config_name="model_test", + model_type=model_type, + model_name=embed_config["model_name"], + api_key=embed_config["api_key"], +) + + + +# 初始化 TbaseHandler 实例 +tb_config = TBConfig( + tb_type="TbaseHandler", + index_name="muagent_test", + host="127.0.0.1", + port=os.environ['tb_port'], + username=os.environ['tb_username'], + password=os.environ['tb_password'], +) + +# prepare your message +message1 = Message( + session_index="default", + role_name="test1", + role_type="user", + content="hello", + spec_parsed_contents=[{"input": "hello"}], +) + +text = "hi! how can I help you?" 
+message2 = Message(
+    session_index="shuimo",
+    role_name="test2",
+    role_type="assistant",
+    content=text,
+    spec_parsed_contents=[{"answer": text}],
+)
+
+text = "they say hello and hi to each other"
+message3 = Message(
+    session_index="shanshi",
+    role_name="test3",
+    role_type="summary",
+    content=text,
+    spec_parsed_contents=[{"summary": text}],
+)
+
+
+# # append or extend test
+# print("###"*10 + "append or extend" + "###"*10)
+# local_memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config, do_init=True)
+# # append can ignore user_name
+# local_memory_manager.append(message=message1)
+# local_memory_manager.append(message=message2)
+# local_memory_manager.append(message=message3)
+
+
+# # test load from local
+# print("###"*10 + "load local" + "###"*10)
+# local_memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config, do_init=False)
+# print(local_memory_manager.get_memory_pool("default").to_format_messages(
+#     content_key="content", format_type='str'))
+# print(local_memory_manager.get_memory_pool("shuimo").to_format_messages(
+#     content_key="content", format_type='str'))
+# print(local_memory_manager.get_memory_pool("shanshi").to_format_messages(
+#     content_key="content", format_type='str'))
+
+
+# embedding retrieval test
+print("###"*10 + "retrieval" + "###"*10)
+local_memory_manager = TbaseMemoryManager(embed_config=embed_config, llm_config=model_config, tb_config=tb_config, do_init=False)
+# text = "say hi to each other,"
+# # retrieval_type=datetime => retrieval from datetime and jieba
+# print(local_memory_manager.router_retrieval(
+#     session_index="shanshi", text=text, datetime=getCurrentDatetime(),
+#     n=30, top_k=5, retrieval_type="datetime"))
+# # retrieval_type=embedding => retrieval from embedding
+# print(local_memory_manager.router_retrieval(
+#     session_index="shanshi", text=text, top_k=5, retrieval_type="embedding"))
+# # retrieval_type=text => retrieval from jieba
+# print(local_memory_manager.router_retrieval(
+#     session_index="shanshi", text=text, top_k=5, retrieval_type="text"))
+
+# recursive_summary test
+print("###"*10 + "recursive_summary" + "###"*10)
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shanshi").messages, split_n=1, session_index="shanshi"))
+
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("shuimo").messages, split_n=1, session_index="shanshi"))
+
+print(local_memory_manager.recursive_summary(local_memory_manager.get_memory_pool("default").messages, split_n=1, session_index="shanshi"))
diff --git a/tests/orm/table_test.py b/tests/orm/table_test.py
index da9980a..ce7b1ca 100644
--- a/tests/orm/table_test.py
+++ b/tests/orm/table_test.py
@@ -2,7 +2,8 @@
 from loguru import logger
 
 os.environ["do_create_dir"] = "1"
-from muagent.orm import create_tables
+# from muagent.orm import create_tables
+from muagent.db_handler import create_tables
 
 # use to test, don't create some directory
diff --git a/tests/prompt_manager/base_test.py b/tests/prompt_manager/base_test.py
new file mode 100644
index 0000000..b828e91
--- /dev/null
+++ b/tests/prompt_manager/base_test.py
@@ -0,0 +1,116 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+    src_dir = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    )
+    sys.path.append(src_dir)
+    import test_config
+except Exception as e:
+    # set your config
+    logger.error(f"{e}")
+
+
+# test local code
+src_dir = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.schemas import Message, Memory
+from muagent.prompt_manager import CommonPromptManager
+
+
+system_prompt = """#### Agent Profile
+As an agent specializing in software quality assurance,
+your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
+This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods, collectively referred to as Retrieval Code Snippets.
+Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
+Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
+
+ATTENTION: respond strictly in the format described under "Response Output Format".
+
+Each test case should include:
+1. A clear description of the test purpose.
+2. The input values or conditions for the test.
+3. The expected outcome or assertion for the test.
+4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
+5. The test code should include package and import statements.
+
+#### Input Format
+
+**Code Snippet:** the initial Code or objective that the user wanted to achieve
+
+**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
+
+#### Response Output Format
+**SaveFileName:** construct a local file name based on Question and Context, such as
+
+```java
+package/class.java
+```
+
+**Test Code:** generate the test code for the current Code Snippet.
+```java
+...
+```
+
+"""
+
+input_template = ""
+output_template = ""
+prompt = ""
+
+agent_names = ["agent1", "agent2"]
+agent_descs = [f"hello {agent}" for agent in agent_names]
+tools = ["Multiplier", "WeatherInfo"]
+
+bpm = CommonPromptManager(
+    system_prompt=system_prompt,
+    input_template=input_template,
+    output_template=output_template,
+    prompt=prompt,
+)
+
+#
+message1 = Message(
+    role_name="test",
+    role_type="user",
+    content="hello"
+)
+message2 = Message(
+    role_name="test",
+    role_type="assistant",
+    content="hi! can i help you!"
+)
+query = Message(
+    role_name="test",
+    role_type="user",
+    input_text="i want to know the weather of beijing",
+    content="i want to know the weather of beijing",
+    spec_parsed_content={
+        "Retrieval Code Snippets": "hi"
+    },
+    global_kwargs={
+        "Code Snippet": "hello",
+        "Test Code": "nice to meet you."
+    }
+)
+memory = Memory(messages=[message1, message2])
+
+# prompt = bpm.pre_print(
+#     query=query, memory=memory, tools=tools,
+#     agent_names=agent_names, agent_descs=agent_descs
+# )
+# print(prompt)
+
+prompt = bpm.generate_prompt(
+    query=query, memory=memory, tools=tools,
+    agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
\ No newline at end of file
diff --git a/tests/prompt_manager/extend_common_pm_test.py b/tests/prompt_manager/extend_common_pm_test.py
new file mode 100644
index 0000000..850f07b
--- /dev/null
+++ b/tests/prompt_manager/extend_common_pm_test.py
@@ -0,0 +1,161 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+    src_dir = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    )
+    sys.path.append(src_dir)
+    import test_config
+except Exception as e:
+    # set your config
+    logger.error(f"{e}")
+
+
+# test local code
+src_dir = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+from muagent.schemas import Memory, Message, PromptConfig
+from muagent.prompt_manager import CommonPromptManager
+
+
+from typing import (
+    List,
+    Any,
+    Union,
+    Optional,
+    Dict,
+    Literal
+)
+from pydantic import BaseModel
+
+class NewPromptManager(CommonPromptManager):
+
+    pm_type: str = "NewPromptManager"
+    """The type of prompt manager."""
+
+    def __init__(
+        self,
+        system_prompt: str = "you are a helpful assistant!\n",
+        input_template: Union[str, BaseModel] = "",
+        output_template: Union[str, BaseModel] = "",
+        prompt: Optional[str] = None,
+        language: Literal["en", "zh"] = "en",
+        *,
+        monitored_agents=[],
+        monitored_fields=[],
+        **kwargs
+    ):
+        # update new titles
+        extra_registry_titles: Dict = {
+            "EXAMPLE": {
+                "description": "Here are some examples for reference.",
+                "function": "",
+                "display_type": "title"
+            },
+            "INPUT EXAMPLE": {
+                "description": "this input example",
+                "prompt": "",
+                "function": "",
+                "display_type": "description"
+            },
+            "OUTPUT EXAMPLE": {
+                "description": "this output example",
+                "prompt": "",
+                "function": "handle_empty_key",
+                "display_type": "description"
+            },
+        }
+        #
+        extra_register_edges: List = [
+            ("EXAMPLE", "INPUT EXAMPLE"),
+            ("EXAMPLE", "OUTPUT EXAMPLE"),
+        ]
+
+        #
+        new_dfsindex_to_str_format: Dict = {
+            0: "#### {}\n{}",
+            1: "### {}\n{}",
+            2: "## {}\n{}",
+            3: "# {}\n{}",
+        }
+        """use {title name} {description/function_value}"""
+
+        super().__init__(
+            system_prompt=system_prompt,
+            input_template=input_template,
+            output_template=output_template,
+            prompt=prompt,
+            language=language,
+            extra_registry_titles=extra_registry_titles,
+            extra_register_edges=extra_register_edges,
+            new_dfsindex_to_str_format=new_dfsindex_to_str_format,
+            monitored_agents=monitored_agents,
+            monitored_fields=monitored_fields,
+            **kwargs
+        )
+
+
+system_prompt = "you are a helpful assistant!\n"
+input_template = ""
+output_template = ""
+prompt = ""
+
+agent_names = ["agent1", "agent2"]
+agent_descs = [f"hello {agent}" for agent in agent_names]
+tools = ["Multiplier", "WeatherInfo"]
+
+
+bpm = NewPromptManager(
+    # system_prompt=system_prompt,
+    # input_template=input_template,
+    # output_template=output_template,
+    # prompt=prompt,
+    language="en",
+)
+
+
+#
+message1 = Message(
+    role_name="test",
+    role_type="user",
+    content="hello"
+)
+message2 = Message(
+    role_name="test",
+    role_type="assistant",
+    content="hi! can i help you!"
+) +query = Message( + role_name="test", + role_type="user", + input_text="i want to know the weather of beijing", + content="i want to know the weather of beijing", + spec_parsed_content={ + "Retrieval Code Snippets": "hi" + }, + global_kwargs={ + "Code Snippet": "hello", + "Test Code": "nice to meet you." + } +) +memory = Memory(messages=[message1, message2]) + +# prompt = bpm.pre_print( +# query=query, memory=memory, tools=tools, +# agent_names=agent_names, agent_descs=agent_descs +# ) +# print(prompt) + +prompt = bpm.generate_prompt( + query=query, memory=memory, tools=tools, + agent_names=agent_names, agent_descs=agent_descs +) +print(prompt) \ No newline at end of file diff --git a/tests/prompt_manager/new_pm_test.py b/tests/prompt_manager/new_pm_test.py new file mode 100644 index 0000000..6735fa1 --- /dev/null +++ b/tests/prompt_manager/new_pm_test.py @@ -0,0 +1,233 @@ +import os, sys +from loguru import logger +import json + +os.environ["do_create_dir"] = "1" + +try: + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + + +# test local code +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + + +from muagent.schemas import Memory, Message, PromptConfig +from muagent.prompt_manager import BasePromptManager +from muagent.agents import BaseAgent + + +from typing import ( + List, + Any, + Union, + Optional, + Dict, + Literal +) +from pydantic import BaseModel + +class NewPromptManager(BasePromptManager): + + pm_type: str = "NewPromptManager" + """The type of prompt manager.""" + + def __init__( + self, + system_prompt: str = "you are a helpful assistant!\n", + input_template: Union[str, BaseModel] = "", + output_template: Union[str, BaseModel] = "", + prompt: Optional[str] = None, + language: Literal["en", "zh"] = "en", + *, + monitored_agents=[], + monitored_fields=[], + **kwargs + ): + super().__init__( + system_prompt=system_prompt, + input_template=input_template, + output_template=output_template, + prompt=prompt, + language=language, + monitored_agents=monitored_agents, + monitored_fields=monitored_fields, + **kwargs + ) + # update new titles + self.extra_registry_titles: Dict = { + "AGENT PROFILE": { + "description": "", + "function": "handle_agent_profile", + "display_type": "title" + }, + "TOOL INFORMATION": { + "description": "", + "prompt": ( + 'Below is a list of tools that are available for your use:{formatted_tools}' + '\nvalid "tool_name" value is:\n{tool_names}' + ), + "function": "handle_tool_data", + "display_type": "description", + "str_template": "**{}\n{}" + }, + "AGENT INFORMATION": { + "description": "", + "prompt": ( + 'Please ensure your selection is one of the listed roles. 
Available roles for selection:\n{agents}'
+                'Please ensure you select the Role from the agent names, such as {agent_names}'
+            ),
+            "function": "handle_agent_data",
+            "display_type": "description"
+        },
+    }
+        #
+        self.extra_register_edges: List = [
+            ("AGENT PROFILE", "AGENT INFORMATION"),
+            ("AGENT PROFILE", "TOOL INFORMATION"),
+        ]
+
+        #
+
+        self.new_dfsindex_to_str_format: Dict = {
+            0: "#### {}\n{}",
+            1: "### {}\n{}",
+            2: "## {}\n{}",
+            3: "# {}\n{}",
+        }
+        """use {title name} {description/function_value}"""
+        #
+        self.register_graph({}, [], {}, {})
+
+    def register_prompt(self):
+        """register input/output/prompt into titles and edges"""
+        pass
+
+    def handle_agent_profile(self, **kwargs) -> str:
+        return self.system_prompt
+
+    def handle_tool_data(self, **kwargs):
+        import random
+        from textwrap import dedent
+        from muagent.tools import get_tool, BaseToolModel
+
+        if 'tools' not in kwargs: return ""
+
+        tools: List = kwargs.get('tools')
+        prompt: str = kwargs.get('prompt')
+        tools: List[BaseToolModel] = [get_tool(tool) for tool in tools if isinstance(tool, str)]
+
+        if len(tools) == 0: return ""
+
+        tool_strings = []
+        for tool in tools:
+            # note: intput_to_json_schema (sic) is the method name exposed by the tool base class
+            args_str = f'args: {str(tool.intput_to_json_schema())}' if tool.ToolInputArgs else ""
+            tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
+        formatted_tools = "\n".join(tool_strings)
+
+        tool_names = ", ".join([tool.name for tool in tools])
+
+        tool_prompt = dedent(prompt.format(formatted_tools=formatted_tools, tool_names=tool_names))
+        while "\n " in tool_prompt:
+            tool_prompt = tool_prompt.replace("\n ", "\n")
+
+        return tool_prompt
+
+    def handle_agent_data(self, **kwargs):
+        """"""
+        import random
+        from textwrap import dedent
+        if 'agent_names' not in kwargs or "agent_descs" not in kwargs:
+            return ""
+
+        agent_names: List = kwargs.get('agent_names')
+        agent_descs: List = kwargs.get('agent_descs')
+        prompt: str = kwargs.get('prompt')
+
+        if len(agent_names) == 0: return ""
+
+        # shuffle names and descriptions together so each name keeps its own description
+        paired = list(zip(agent_names, agent_descs))
+        random.shuffle(paired)
+        agent_descriptions = []
+        for agent_name, desc in paired:
+            while "\n\n" in desc:
+                desc = desc.replace("\n\n", "\n")
+            desc = desc.replace("\n", ",")
+            agent_descriptions.append(
+                f'"role name: {agent_name}\nrole description: {desc}"'
+            )
+
+        agent_description = "\n".join(agent_descriptions)
+        agent_prompt = dedent(
+            prompt.format(agents=agent_description, agent_names=agent_names)
+        )
+
+        while "\n " in agent_prompt:
+            agent_prompt = agent_prompt.replace("\n ", "\n")
+
+        return agent_prompt
+
+system_prompt = "you are a helpful assistant!\n"
+input_template = ""
+output_template = ""
+prompt = ""
+
+agent_names = ["agent1", "agent2"]
+agent_descs = [f"hello {agent}" for agent in agent_names]
+tools = ["Multiplier", "WeatherInfo"]
+
+
+bpm = NewPromptManager(
+    # system_prompt=system_prompt,
+    # input_template=input_template,
+    # output_template=output_template,
+    # prompt=prompt,
+    language="zh",
+)
+
+#
+message1 = Message(
+    role_name="test",
+    role_type="user",
+    content="hello"
+)
+message2 = Message(
+    role_name="test",
+    role_type="assistant",
+    content="hi! can i help you!"
+)
+query = Message(
+    role_name="test",
+    role_type="user",
+    input_text="i want to know the weather of beijing",
+    content="i want to know the weather of beijing",
+    spec_parsed_content={
+        "Retrieval Code Snippets": "hi"
+    },
+    global_kwargs={
+        "Code Snippet": "hello",
+        "Test Code": "nice to meet you."
+    }
+)
+memory = Memory(messages=[message1, message2])
+
+prompt = bpm.pre_print(
+    query=query, memory=memory, tools=tools,
+    agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
+
+prompt = bpm.generate_prompt(
+    query=query, memory=memory, tools=tools,
+    agent_names=agent_names, agent_descs=agent_descs
+)
+print(prompt)
\ No newline at end of file
diff --git a/tests/retrieval/faiss_test.py b/tests/retrieval/faiss_test.py
new file mode 100644
index 0000000..101822f
--- /dev/null
+++ b/tests/retrieval/faiss_test.py
@@ -0,0 +1,69 @@
+import os, sys
+from loguru import logger
+import json
+
+os.environ["do_create_dir"] = "1"
+
+try:
+    src_dir = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    )
+    sys.path.append(src_dir)
+    import test_config
+    api_key = os.environ["OPENAI_API_KEY"]
+    api_base_url = os.environ["API_BASE_URL"]
+    model_name = os.environ["model_name"]
+    model_engine = os.environ["model_engine"]
+    embed_model = os.environ["embed_model"]
+    embed_model_path = os.environ["embed_model_path"]
+except Exception as e:
+    # set your config
+    api_key = ""
+    api_base_url = ""
+    model_name = ""
+    model_engine = ""
+    embed_model = ""
+    embed_model_path = ""
+    logger.error(f"{e}")
+
+# test local code
+src_dir = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+
+
+from muagent.db_handler import LocalFaissHandler
+from muagent.schemas.db import DBConfig, GBConfig, VBConfig, TBConfig
+from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from muagent.models import ModelConfig
+import numpy as np
+
+llm_config = LLMConfig(
+    model_name=model_name, model_engine=model_engine, api_key=api_key, api_base_url=api_base_url, temperature=0.3,
+)
+
+model_configs = json.loads(os.environ["MODEL_CONFIGS"])
+model_type = "ollama_embedding"
+model_config = model_configs[model_type]
+
+embed_config = ModelConfig(
+    config_name="model_test",
+    model_type=model_type,
+    model_name=model_config["model_name"],
+    api_key=model_config["api_key"],
+)
+#
+import random
+embedding = [random.random() for _ in range(768)]
+print(len(embedding), np.mean(embedding))
+
+
+vb_config = VBConfig(vb_type="LocalFaissHandler")
+vb = LocalFaissHandler(embed_config, vb_config)
+
+vb.create_vs("shanshi")
+# query the underlying faiss index directly with the random 768-dim vector
+vector = np.array([embedding], dtype=np.float32)
+scores, indices = vb.search_index.index.search(vector, 20)
+print(scores)
diff --git a/tests/sandbox/nbclient_test.py b/tests/sandbox/nbclient_test.py
new file mode 100644
index 0000000..5367b25
--- /dev/null
+++ b/tests/sandbox/nbclient_test.py
@@ -0,0 +1,218 @@
+# # -*- coding: utf-8 -*-
+# # pylint: disable=C0301
+# """Service for executing jupyter notebooks interactively
+# Partially referenced the implementation of
+# https://github.com/geekan/MetaGPT/blob/main/metagpt/actions/di/execute_nb_code.py
+# """
+# import base64
+# import asyncio
+# from loguru import logger
+
+# try:
+#     import nbclient
+#     import nbformat
+# except ImportError:
+#     nbclient = None
+#     nbformat = None
+
+
+# class NoteBookExecutor:
+#     """
+#     Class for executing jupyter notebooks block interactively.
+#     To use the service function, you should first init the class, then call the
+#     run_code_on_notebook function.
+ +# Example: + +# ```ipython +# from agentscope.service.service_toolkit import * +# from agentscope.service.execute_code.exec_notebook import * +# nbe = NoteBookExecutor() +# code = "print('helloworld')" +# # calling directly +# nbe.run_code_on_notebook(code) + +# >>> Executing function run_code_on_notebook with arguments: +# >>> code: print('helloworld') +# >>> END + +# # calling with service toolkit +# service_toolkit = ServiceToolkit() +# service_toolkit.add(nbe.run_code_on_notebook) +# input_obs = [{"name": "run_code_on_notebook", "arguments":{"code": code}}] +# res_of_string_input = service_toolkit.parse_and_call_func(input_obs) + +# "1. Execute function run_code_on_notebook\n [ARGUMENTS]:\n code: print('helloworld')\n [STATUS]: SUCCESS\n [RESULT]: ['helloworld\\n']\n" + +# ``` +# """ # noqa + +# def __init__( +# self, +# timeout: int = 300, +# ) -> None: +# """ +# The construct function of the NoteBookExecutor. +# Args: +# timeout (Optional`int`): +# The timeout for each cell execution. +# Default to 300. +# """ + +# if nbclient is None or nbformat is None: +# raise ImportError( +# "The package nbclient or nbformat is not found. Please " +# "install it by `pip install notebook nbclient nbformat`", +# ) + +# self.nb = nbformat.v4.new_notebook() +# self.nb_client = nbclient.NotebookClient(nb=self.nb) +# self.timeout = timeout + +# asyncio.run(self._start_client()) + +# def _output_parser(self, output: dict) -> str: +# """Parse the output of the notebook cell and return str""" +# if output["output_type"] == "stream": +# return output["text"] +# elif output["output_type"] == "execute_result": +# return output["data"]["text/plain"] +# elif output["output_type"] == "display_data": +# if "image/png" in output["data"]: +# file_path = self._save_image(output["data"]["image/png"]) +# return f"Displayed image saved to {file_path}" +# else: +# return "Unsupported display type" +# elif output["output_type"] == "error": +# return output["traceback"] +# else: +# logger.info(f"Unsupported output encountered: {output}") +# return "Unsupported output encountered" + +# async def _start_client(self) -> None: +# """start notebook client""" +# if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): +# self.nb_client.create_kernel_manager() +# self.nb_client.start_new_kernel() +# self.nb_client.start_new_kernel_client() + +# async def _kill_client(self) -> None: +# """kill notebook client""" +# if ( +# self.nb_client.km is not None +# and await self.nb_client.km.is_alive() +# ): +# await self.nb_client.km.shutdown_kernel(now=True) +# await self.nb_client.km.cleanup_resources() + +# self.nb_client.kc.stop_channels() +# self.nb_client.kc = None +# self.nb_client.km = None + +# async def _restart_client(self) -> None: +# """Restart the notebook client""" +# await self._kill_client() +# self.nb_client = nbclient.NotebookClient(self.nb, timeout=self.timeout) +# await self._start_client() + +# async def _run_cell(self, cell_index: int): +# """Run a cell in the notebook by its index""" +# try: +# self.nb_client.execute_cell(self.nb.cells[cell_index], cell_index) +# return [self._output_parser(output) for output in self.nb.cells[cell_index].outputs] +# except nbclient.exceptions.DeadKernelError: +# await self.reset_notebook() +# return "DeadKernelError when executing cell, reset kernel" +# except nbclient.exceptions.CellTimeoutError: +# assert self.nb_client.km is not None +# await self.nb_client.km.interrupt_kernel() +# return ( +# "CellTimeoutError when executing cell" +# ", code execution 
timeout" +# ) +# except Exception as e: +# return str(e) + +# @property +# def cells_length(self) -> int: +# """return cell length""" +# return len(self.nb.cells) + +# async def async_run_code_on_notebook(self, code: str): +# """ +# Run the code on interactive notebook +# """ +# self.nb.cells.append(nbformat.v4.new_code_cell(code)) +# cell_index = self.cells_length - 1 +# return await self._run_cell(cell_index) + +# def run_code_on_notebook(self, code: str): +# """ +# Run the code on interactive jupyter notebook. + +# Args: +# code (`str`): +# The Python code to be executed in the interactive notebook. + +# Returns: +# `ServiceResponse`: whether the code execution was successful, +# and the output of the code execution. +# """ +# return asyncio.run(self.async_run_code_on_notebook(code)) + +# def reset_notebook(self) -> str: +# """ +# Reset the notebook +# """ +# asyncio.run(self._restart_client()) +# return "Reset notebook" + + +import os +from loguru import logger + +try: + import os, sys + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + + +from muagent.sandbox import NBClientBox, NoteBookExecutor + +nbe = NoteBookExecutor() +code = f""" +x = 1 +y = 1 +z = x+y +print(z) +""" +print(nbe.run_code_on_notebook(code)) + + +code = f"""z +""" +print(nbe.run_code_on_notebook(code)) + + +codebox = NBClientBox() + +reuslt = codebox.chat("```import os\nos.getcwd()```", do_code_exe=True) +print(reuslt) + +reuslt = codebox.chat("```print('hello world!')```", do_code_exe=True) +print(reuslt) + +with NBClientBox(do_code_exe=True) as codebox: + result = codebox.run("'hello world!'") + print(result) \ No newline at end of file diff --git a/tests/service/ekg_project_test.py b/tests/service/ekg_project_test.py new file mode 100644 index 0000000..b829da4 --- /dev/null +++ b/tests/service/ekg_project_test.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- +from loguru import logger +import os, sys +import json + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) +try: + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +import logging +# Set the logging level to WARNING, which will suppress INFO and DEBUG messages +logging.basicConfig(level=logging.ERROR) + + +from muagent import EKG, get_ekg_project_config_from_env +# nodes = [{'id': 'haPvrjEkz4LARZyR7OAuPmVMHMIQPMew', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '需要公司多人参与的事务,以及相关的问题', 'name': '公司事务'}}, +# {'id': 'dicVRAk5rT3y9LxcmBCN2jDi1TjHc5rm', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '与个人有关的事务(如个人贷款),或遇到的个人问题,不涉及公司事务', +# 'name': '个人事务'}}, +# {'id': 'ClKvwjBRZUJC7ttSZaiT0dh7lhSujNWi', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '公司活动', 'name': '公司活动'}}, +# {'id': 'NyBXAHQckQx1xL5lnSgBGlotbZkkQ9C7', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '金融(如借款、存款、贷款等)', 'name': '金融'}}, +# {'id': '6sa4zJCnVKJxKMtOtypapjZk4sdo93QU', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '医疗(包括预约、挂号、看病、诊断等)', 'name': '医疗'}}, +# {'id': 
'a8d85669_141a_4f54_ab8c_209c08d27c35', +# 'type': 'opsgptkg_schedule', +# 'attributes': {'description': '组织一次公司活动', +# 'name': '组织一次公司活动', +# 'enable': 'False'}}, +# {'id': '2b8df337_f29e_4d49_865f_84088c3a94e7', +# 'type': 'opsgptkg_schedule', +# 'attributes': {'description': '在线申请贷款', +# 'name': '在线申请贷款', +# 'enable': 'False'}}, +# {'id': 'b9fe38f1_33f6_468b_a1dd_43efdfd8e2d1', +# 'type': 'opsgptkg_schedule', +# 'attributes': {'description': '预约医生', 'name': '预约医生', 'enable': 'False'}}, +# {'id': '98234102_4e4a_4997_9b1e_3cda6382b1c7', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '确定活动主题:确定活动的主要目的(如团建、庆祝活动等)', +# 'name': '确定活动主题:确定活动的主要目的(如团建、庆祝活动等)'}}, +# {'id': '59030678_760d_4a10_8d61_0d4e4cc5fbcb', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '访问贷款平台:输入网址并访问贷款申请网站', +# 'name': '访问贷款平台:输入网址并访问贷款申请网站'}}, +# {'id': '5afab73b_8f03_422f_856e_386f183bdd71', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '选择医院/医生:访问医院官网或APP,查找相关科室和医生', +# 'name': '选择医院/医生:访问医院官网或APP,查找相关科室和医生'}}, +# {'id': '95ec00ef_cc9c_4947_a21c_88eeb9a71af5', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '选择活动类型', 'name': '选择活动类型'}}, +# {'id': '5504af87_416e_4ee5_bfce_86b969a63433', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '注册/登录:如果你已经注册,输入用户名和密码进行登录。如果你还没有注册,点击“注册”按钮,填写个人信息,创建账户', +# 'name': '注册/登录:如果你已经注册,输入用户名和密码进行登录。如果你还没有注册,点击“注册”按钮,填写个人信息,创建账户'}}, +# {'id': '3ff8f54a_fa65_4368_86ce_d65058035dd0', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '查看可预约时间:点击医生姓名,查看可预约时段', +# 'name': '查看可预约时间:点击医生姓名,查看可预约时段'}}, +# {'id': 'd5e760b4_ae82_410d_a73d_4c0c98926ae5', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '室内活动', 'name': '室内活动'}}, +# {'id': '2a37b90a_fd96_4548_989c_7c1e8fa9d881', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '户外活动', 'name': '户外活动'}}, +# {'id': '88d4cf2b_7cf5_4e40_b54e_59268f119f63', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '选择贷款类型:浏览可用的贷款类型(如个人贷款、汽车贷款、房屋贷款),选择适合自己的贷款类型', +# 'name': '选择贷款类型:浏览可用的贷款类型(如个人贷款、汽车贷款、房屋贷款),选择适合自己的贷款类型'}}, +# {'id': '39021995_6e63_4907_9d67_26ba50d0cd44', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '填写个人信息:输入姓名、联系方式等,选择预约时间', +# 'name': '填写个人信息:输入姓名、联系方式等,选择预约时间'}}, +# {'id': '59fe9c1d_0731_403e_936a_2e2bbba4b3ee', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '选择具体的室内活动(如会议、晚会、游戏),确定场地和时间,准备相关的设备(如投影仪、音响),安排餐饮和娱乐节目,发出邀请通知', +# 'name': '选择具体的室内活动(如会议、晚会、游戏),确定场地和时间,准备相关的设备(如投影仪、音响),安排餐饮和娱乐节目,发出邀请通知'}}, +# {'id': '60163dc6_87af_4972_b350_6b9275975c83', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '选择具体的户外活动(如远足、烧烤、运动会),确定地点和时间,安排交通工具和安全措施,联系供应商(如餐饮、设备租赁),发出邀请通知', +# 'name': '选择具体的户外活动(如远足、烧烤、运动会),确定地点和时间,安排交通工具和安全措施,联系供应商(如餐饮、设备租赁),发出邀请通知'}}, +# {'id': '910f3634_b999_4cf3_94c9_346a67b0d5ed', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '填写申请表:提供个人信息(如姓名、年龄、收入等),提供贷款金额和贷款目的', +# 'name': '填写申请表:提供个人信息(如姓名、年龄、收入等),提供贷款金额和贷款目的'}}, +# {'id': '1330ad69_dfc3_4538_864e_6867a3fd8dd4', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '确认预约:检查预约信息,点击“确认预约”按钮', +# 'name': '确认预约:检查预约信息,点击“确认预约”按钮'}}, +# {'id': 'fcbc3e04_ad8c_4aad_9f75_191f8037ced8', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '预算审核:计算活动预估费用,提交预算给管理层审核', +# 'name': '预算审核:计算活动预估费用,提交预算给管理层审核'}}, +# {'id': '2c7a0d7b_a490_41b9_a6f8_e71b5212e0be', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '提交资料:上传所需文件(如身份证、收入证明等)', +# 'name': '提交资料:上传所需文件(如身份证、收入证明等)'}}, +# {'id': 
'3cd46fb7_e11c_4181_8670_2f080a453142', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '接收通知:收到预约确认短信或邮件', +# 'name': '接收通知:收到预约确认短信或邮件'}}, +# {'id': '0f4610cd_cf6a_475b_8ac0_80166569a292', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '审核资料:系统开始审核申请', 'name': '审核资料:系统开始审核申请'}}, +# {'id': 'b9f81925_b43a_459d_9902_1bc4b024f5a1', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '审核通过', 'name': '审核通过'}}, +# {'id': '191687cd_1b76_4e77_9f2a_e67936dd372e', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '审核失败', 'name': '审核失败'}}, +# {'id': '18c33ec1_08ef_4df8_b938_7244852d19c8', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '用户收到“申请通过”的通知,前往下一步选择贷款期限和还款方式', +# 'name': '用户收到“申请通过”的通知,前往下一步选择贷款期限和还款方式'}}, +# {'id': 'b73c2551_0890_40fb_b0ca_04912bc21b65', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '提供反馈,建议修改后重新申请', 'name': '提供反馈,建议修改后重新申请'}}, +# {'id': 'e95adaa2_d177_435b_bac7_a8b6047ecc3d', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '确认贷款条件:查看贷款条款和条件', +# 'name': '确认贷款条件:查看贷款条款和条件'}}, +# {'id': '0c561d68_ee31_49d2_82c1_1dac81e731ff', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '拒绝条款', 'name': '拒绝条款'}}, +# {'id': '81f579ac_851d_4b85_8608_d2732a2612ff', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '接受条款', 'name': '接受条款'}}, +# {'id': '1f0b64aa_5d45_4cf5_bcdd_084b8c125889', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '选择“拒绝”并退出申请流程', 'name': '选择“拒绝”并退出申请流程'}}, +# {'id': '5fd5901a_8adc_4b76_aea2_dcf18884ea0e', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '点击“接受”并继续', 'name': '点击“接受”并继续'}}, +# {'id': '8c999c60_baa7_4e74_903b_f10f148dd12f', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '签署合同:在线签署贷款合同', 'name': '签署合同:在线签署贷款合同'}}, +# {'id': 'e1004c60_5c0c_4f32_b765_a57cc4d39dcc', +# 'type': 'opsgptkg_analysis', +# 'attributes': {'summaryswitch': 'False', +# 'description': '根据提示前往医院就诊', +# 'name': '根据提示前往医院就诊'}}, +# {'id': 'c50ff5e3_aa01_4a6c_96d7_d8645303846d', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '活动宣传:制作宣传材料(如海报、邮件通知),在公司内部推广活动信息', +# 'name': '活动宣传:制作宣传材料(如海报、邮件通知),在公司内部推广活动信息'}}, +# {'id': '4f540a57_f73d_451e_aafb_43f1335a18a7', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '活动实施:根据选择的活动类型,执行相关安排,进行现场协调(无论是户外还是室内)', +# 'name': '活动实施:根据选择的活动类型,执行相关安排,进行现场协调(无论是户外还是室内)'}}, +# {'id': 'c9952fa7_7f82_4737_8cfd_bdbb2dabb20e', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '活动反馈:收集参与者的反馈意见,总结活动的成功之处和改进建议', +# 'name': '活动反馈:收集参与者的反馈意见,总结活动的成功之处和改进建议'}}, +# {'id': 'ekg_team_default', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '团队起始节点', 'name': '开始'}}] + +# nodes = [{'id': '剧本杀/谁是卧底', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '谁是卧底', 'name': '谁是卧底', 'extra': ''}}, +# {'id': '剧本杀/狼人杀', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '狼人杀', 'name': '狼人杀', 'extra': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互', +# 'type': 'opsgptkg_schedule', +# 'attributes': {'extra': '', +# 'description': '智能交互', +# 'name': '智能交互', +# 'enable': True}}, +# {'id': '剧本杀/狼人杀/智能交互', +# 'type': 'opsgptkg_schedule', +# 'attributes': {'extra': '', +# 'description': '智能交互', +# 'name': '智能交互', +# 'enable': False}}, +# {'id': '剧本杀/谁是卧底/智能交互/分配座位', +# 'type': 'opsgptkg_task', +# 'attributes': {'extra': '{"dodisplay":"True"}', +# 'executetype': '', +# 'description': '分配座位', +# 'name': '分配座位', +# 'accesscriteria': ''}}, +# {'id': 
'剧本杀/狼人杀/智能交互/位置选择', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '位置选择', +# 'name': '位置选择', +# 'accesscriteria': '', +# 'extra': '{"memory_tag": "all"}', +# 'executetype': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/角色分配和单词分配', +# 'type': 'opsgptkg_task', +# 'attributes': {'accesscriteria': '', +# 'extra': '{"memory_tag": "None","dodisplay":"True"}', +# 'executetype': '', +# 'description': '角色分配和单词分配', +# 'name': '角色分配和单词分配'}}, +# {'id': '剧本杀/狼人杀/智能交互/角色选择', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '角色选择', +# 'name': '角色选择', +# 'accesscriteria': '', +# 'extra': '{"memory_tag": "None"}', +# 'executetype': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/通知身份', +# 'type': 'opsgptkg_task', +# 'attributes': {'extra': '{"pattern": "react","dodisplay":"True"}', +# 'executetype': '', +# 'description': '##角色##\n你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。\n目前已经完成 1)位置分配; 2)角色分配和单词分配。\n##任务##\n向所有玩家通知信息他们的 座位信息和单词信息。\n发送格式是: 【身份通知】你是{player_name}, 你的位置是{位置号}号, 你分配的单词是{单词}\n##详细步骤##\nstep1.依次向所有玩家通知信息他们的 座位信息和单词信息。发送格式是: 你是{player_name}, 你的位置是{位置号}号, 你分配的单词是{单词}\nstpe2.所有玩家信息都发送后,结束\n\n##注意##\n1. 每条信息只能发送给对应的玩家,其他人无法看到。\n2. 不要告诉玩家的角色信息,即不要高斯他是平民还是卧底角色\n3. 在将每个人的信息通知到后,本阶段任务结束\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n[{"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}, ...]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n#example#\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}', +# 'name': '通知身份', +# 'accesscriteria': ''}}, +# {'id': '剧本杀/狼人杀/智能交互/向玩家通知消息', +# 'type': 'opsgptkg_task', +# 'attributes': {'extra': '{"pattern": "react"}', +# 'executetype': '', +# 'description': '##角色##\n你正在参与狼人杀这个游戏,你的角色是[主持人]。你熟悉狼人杀游戏的完整流程,你需要完成[任务],保证狼人杀游戏的顺利进行。\n目前已经完成位置分配和角色分配。\n##任务##\n向所有玩家通知信息他们的座位信息和角色信息。\n发送格式是: 你是{player_name}, 你的位置是{位置号}号,你的身份是{角色名}\n##注意##\n1. 
每条信息只能发送给对应的玩家,其他人无法看到。\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n\n[{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果##', +# 'name': '向玩家通知消息', +# 'accesscriteria': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/关键信息_1', +# 'type': 'opsgptkg_task', +# 'attributes': {'executetype': '', +# 'description': '关键信息', +# 'name': '关键信息', +# 'accesscriteria': '', +# 'extra': '{"ignorememory":"True","dodisplay":"True"}'}}, +# {'id': '剧本杀/狼人杀/智能交互/狼人时刻', +# 'type': 'opsgptkg_task', +# 'attributes': {'accesscriteria': 'OR', +# 'extra': '{"pattern": "react"}', +# 'executetype': '', +# 'description': '##背景##\n在狼人杀游戏中,主持人通知当前存活的狼人玩家指认一位击杀对象,所有狼人玩家给出击杀目标,主持人确定最终结果。\n\n##任务##\n整个流程分为6个步骤:\n1. 存活狼人通知:主持人向所有的狼人玩家广播,告知他们当前存活的狼人玩家有哪些。\n2. 第一轮讨论:主持人告知所有存活的狼人玩家投票,从当前存活的非狼人玩家中,挑选一个想要击杀的玩家。\n3. 第一轮投票:按照座位顺序,每一位存活的狼人为自己想要击杀的玩家投票。\n4. 第一轮结果反馈:主持人统计所有狼人的票数分布,确定他们是否达成一致。若达成一致,告知所有狼人最终被击杀的玩家的player_name,流程结束;否则,告知他们票数的分布情况,并让所有狼人重新投票指定击杀目标,主持人需要提醒他们,若该轮还不能达成一致,则取票数最大的目标为最终击杀对象。\n5. 第二轮投票:按照座位顺序,每一位存活的狼人为自己想要击杀的玩家投票。\n6. 第二轮结果反馈:主持人统计第二轮投票中所有狼人的票数分布,取票数最大的玩家为最终击杀对象,如果存在至少两个对象的票数最大且相同,取座位号最大的作为最终击杀对象。主持人告知所有狼人玩家最终被击杀的玩家的player_name。\n\n该任务的参与者只有狼人玩家和主持人,信息可见对象是所有狼人玩家。\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n\n[{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 
整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果##', +# 'name': '狼人时刻'}}, +# {'id': '剧本杀/谁是卧底/智能交互/开始新一轮的讨论', +# 'type': 'opsgptkg_task', +# 'attributes': {'accesscriteria': 'OR', +# 'extra': '{"pattern": "react", "endcheck": "True",\n"memory_tag":"all",\n"dodisplay":"True"}', +# 'executetype': '', +# 'description': '###以上为本局游戏记录###\n\n\n##背景##\n你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。\n\n##任务##\n以结构化的语句来模拟进行 谁是卧底的讨论环节。 在这一个环节里,所有主持人先宣布目前存活的玩家,然后每位玩家按照座位顺序发言\n\n\n##详细步骤##\nstep1. 主持人根据本局游戏历史记录,感知最开始所有的玩家 以及 在前面轮数中已经被票选死亡的玩家。注意死亡的玩家不能参与本轮游戏。得到当前存活的玩家个数以及其player_name。 并告知所有玩家当前存活的玩家个数以及其player_name。\nstep2. 主持人确定发言规则并告知所有玩家,发言规则步骤如下: 存活的玩家按照座位顺序由小到大进行发言\n(一个例子:假设总共有5个玩家,如果3号位置处玩家死亡,则发言顺序为:1_>2_>4_>5)\nstep3. 存活的的玩家按照顺序依次发言\nstpe4. 在每一位存活的玩家都发言后,结束\n\n \n \n##注意##\n1.之前的游戏轮数可能已经投票选中了某位/某些玩家,被票选中的玩家会立即死亡,不再视为存活玩家,死亡的玩家不能参与本轮游戏 \n2.你要让所有存活玩家都参与发言,不能遗漏任何存活玩家。在本轮所有玩家只发言一次\n3.该任务的参与者为主持人和所有存活的玩家,信息可见对象为所有玩家。\n4.不仅要模拟主持人的发言,还需要模拟玩家的发言\n5.每一位存活的玩家均发完言后,本阶段结束\n\n\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n[ {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}] }, ...]\n\n\n\n\n关键词含义如下:\n_ thought (str): 主持人执行行动的一些思考,包括分析玩家的存活状态,对历史对话信息的理解,对当前任务情况的判断等。 \n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空 ;否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为本条信息的可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"thought": "str", "action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"thought": "str", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}\n6. 如果是人类玩家发言, 一定要选择类似 agent_人类玩家 这样的agent_name', +# 'name': '开始新一轮的讨论'}}, +# {'id': '剧本杀/狼人杀/智能交互/天亮讨论', +# 'type': 'opsgptkg_task', +# 'attributes': {'executetype': '', +# 'description': '##角色##\n你正在参与狼人杀这个游戏,你的角色是[主持人]。你熟悉狼人杀游戏的完整流程,你需要完成[任务],保证狼人杀游戏的顺利进行。\n##任务##\n你的任务如下: \n1. 告诉玩家昨晚发生的情况: 首先告诉玩家天亮了,然后你需要根据过往信息,告诉所有玩家,昨晚是否有玩家死亡。如果有,则向所有人宣布死亡玩家的名字,你只能宣布死亡玩家是谁如:"昨晚xx玩家死了",不要透露任何其他信息。如果没有,则宣布昨晚是平安夜。\n2. 确定发言规则并告诉所有玩家:\n确定发言规则步骤如下: \n第一步:确定第一个发言玩家,第一个发言的玩家为死者的座位号加1位置处的玩家(注意:最后一个位置+1的位置号为1号座位),如无人死亡,则从1号玩家开始。\n第二步:告诉所有玩家从第一个发言玩家开始发言,除了死亡玩家,每个人都需要按座位号依次讨论,只讨论一轮,所有人发言完毕后结束。注意不能遗忘指挥任何存活玩家发言!\n以下是一个例子:\n```\n总共有5个玩家,如果3号位置处玩家死亡,则第一个发言玩家为4号位置处玩家,因此从他开始发言,发言顺序为:4_>5_>1_>2\n```\n3. 依次指定存活玩家依次发言\n4. 被指定的玩家依次发言\n##注意##\n1. 你必须根据规则确定第一个发言玩家是谁,然后根据第一个发言玩家的座位号,确定所有人的发言顺序并将具体发言顺序并告知所有玩家,不要做任何多余解释\n2. 
你要让所有存活玩家都参与发言,不能遗漏任何存活玩家\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n\n[{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}]\n\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果(请直接在后面输出,如果后面已经有部分结果,请续写。一定要保持续写后的内容结合前者能构成一个合法的 jsonstr)##', +# 'name': '天亮讨论', +# 'accesscriteria': '', +# 'extra': '{"pattern": "react"}'}}, +# {'id': '剧本杀/谁是卧底/智能交互/关键信息_2', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '关键信息', +# 'name': '关键信息', +# 'accesscriteria': '', +# 'extra': '{"ignorememory":"True","dodisplay":"True"}', +# 'executetype': ''}}, +# {'id': '剧本杀/狼人杀/智能交互/票选凶手', +# 'type': 'opsgptkg_task', +# 'attributes': {'accesscriteria': '', +# 'extra': '{"pattern": "react"}', +# 'executetype': '', +# 'description': '##角色##\n你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。\n\n##任务##\n你的任务如下:\n1. 告诉玩家投票规则,规则步骤如下: \nstep1: 确定讨论阶段第一个发言的玩家A\nstep2: 从A玩家开始,按座位号依次投票,每个玩家只能对一个玩家进行投票,投票这个玩家表示认为该玩家是“卧底”。每个玩家只能投一次票。\nstep3: 将完整投票规则告诉所有玩家\n2. 指挥存活玩家依次投票。\n3. 被指定的玩家进行投票\n4. 主持人统计投票结果,并告知所有玩家,投出的角色是谁。\n\n该任务的参与者为主持人和所有存活的玩家,信息可见对象是所有玩家。\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为Python可解析的JSON,格式为\n```\n{"action": {player_name, agent_name}, "observation" or "Dungeon_Master": [{content, memory_tag}, ...]}\n```\n关键词含义如下:\n_ player_name (str): 行动方的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): 行动方的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n\n##example##\n如果是玩家发言,则用 {"action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{"content": "str", "memory_tag":["agent_name_a","agent_name_b"]}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{"content": "str", memory_tag:["agent_name_a","agent_name_b"]}]}\n\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 
整个list是一个jsonstr,请直接输出jsonstr,不用输出markdown格式\n\n##结果##\n', +# 'name': '票选凶手'}}, +# {'id': '剧本杀/谁是卧底/智能交互/票选卧底_1', +# 'type': 'opsgptkg_task', +# 'attributes': {'executetype': '', +# 'description': '##以上为本局游戏历史记录##\n##角色##\n你是一个统计票数大师,你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。 现在是投票阶段。\n\n##任务##\n以结构化的语句来模拟进行 谁是卧底的投票环节, 也仅仅只模拟投票环节,投票环节结束后就本阶段就停止了,由后续的阶段继续进行游戏。 在这一个环节里,由主持人先告知大家投票规则,然后组织每位存活玩家按照座位顺序发言投票, 所有人投票后,本阶段结束。 \n##详细步骤##\n你的任务如下:\nstep1. 向所有玩家通知现在进入了票选环节,在这个环节,每个人都一定要投票指定某一个玩家为卧底\nstep2. 主持人确定投票顺序并告知所有玩家。 投票顺序基于如下规则: 1: 存活的玩家按照座位顺序由小到大进行投票(一个例子:假设总共有5个玩家,如果3号位置处玩家死亡,则投票顺序为:1_>2_>4_>5)2: 按座位号依次投票,每个玩家只能对一个玩家进行投票。每个玩家只能投一次票。3:票数最多的玩家会立即死亡\n\nstep3. 存活的的玩家按照顺序进行投票\nstep4. 所有存活玩家发言完毕,主持人宣布投票环节结束\n该任务的参与者为主持人和所有存活的玩家,信息可见对象是所有玩家。\n##注意##\n\n1.之前的游戏轮数可能已经投票选中了某位/某些玩家,被票选中的玩家会立即死亡,不再视为存活玩家 \n2.你要让所有存活玩家都参与投票,不能遗漏任何存活玩家。在本轮每一位玩家只投票一个人\n3.该任务的参与者为主持人和所有存活的玩家,信息可见对象为所有玩家。\n4.不仅要模拟主持人的发言,还需要模拟玩家的发言\n5.不允许玩家自己投自己,如果出现了这种情况,主持人会提醒玩家重新投票。\n\n\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n["thought": str, {"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}, ...]\n关键词含义如下:\n_ thought (str): 主持人执行行动的一些思考,包括分析玩家的存活状态,对历史对话信息的理解,对当前任务情况的判断。 \n_ player_name (str): ***的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): ***的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n##example##\n如果是玩家发言,则用 {"thought": "str", "action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"thought": "str", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}\n6. 如果是人类玩家发言, 一定要选择类似 人类agent 这样的agent_name', +# 'name': '票选卧底', +# 'accesscriteria': '', +# 'extra': '{"pattern": "react", "endcheck": "True", "memory_tag":"all","dodisplay":"True"}'}}, +# {'id': '剧本杀/谁是卧底/智能交互/关键信息_4', +# 'type': 'opsgptkg_task', +# 'attributes': {'extra': '{"ignorememory":"True","dodisplay":"True"}', +# 'executetype': '', +# 'description': '关键信息_4', +# 'name': '关键信息_4', +# 'accesscriteria': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/统计票数', +# 'type': 'opsgptkg_task', +# 'attributes': {'executetype': '', +# 'description': '##以上为本局游戏历史记录##\n##角色##\n你是一个统计票数大师,你非常擅长计数以及统计信息。你正在参与“谁是卧底”这个游戏,你的角色是[主持人]。你熟悉“谁是卧底”游戏的完整流程,你需要完成[任务],保证游戏的顺利进行。 现在是票数统计阶段\n\n##任务##\n以结构化的语句来模拟进行 谁是卧底的票数统计阶段, 也仅仅只票数统计阶段环节,票数统计阶段结束后就本阶段就停止了,由后续的阶段继续进行游戏。 在这一个环节里,由主持人根据上一轮存活的玩家投票结果统计票数。 \n##详细步骤##\n你的任务如下:\nstep1. 主持人感知上一轮投票环节每位玩家的发言, 统计投票结果,格式为[{"player_name":票数}]. \nstep2 然后,主持人宣布死亡的玩家,以最大票数为本轮被投票的目标,如果票数相同,则取座位号高的角色死亡。并告知所有玩家本轮被投票玩家的player_name。(格式为【重要通知】本轮死亡的玩家为XXX)同时向所有玩家宣布,被投票中的角色会视为立即死亡(即不再视为存活角色)\nstep3. 
在宣布死亡玩家后,本阶段流程结束,由后续阶段继续推进游戏\n该任务的参与者为主持人和所有存活的玩家,信息可见对象是所有玩家。\n##注意##\n1.如果有2个或者两个以上的被玩家被投的票数相同,则取座位号高的玩家死亡。并告知大家原因:票数相同,取座位号高的玩家死亡\n2.在统计票数时,首先确认存活玩家的数量,再先仔细回忆,谁被投了。 最后统计每位玩家被投的次数。 由于每位玩家只有一票,所以被投次数的总和等于存活玩家的数量 \n3.通知完死亡玩家是谁后,本阶段才结束,由后续阶段继续推进游戏。输出 {"action": "taskend"}即可\n4.主持人只有当通知本轮死亡的玩家时,才使用【重要通知】的前缀,其他情况下不要使用【重要通知】前缀\n5.只统计上一轮投票环节的情况\n##example##\n{"thought": "在上一轮中, 存活玩家有 小北,李光,赵鹤,张良 四个人。 其中 小北投了李光, 赵鹤投了小北, 张良投了李光, 李光投了张良。总结被投票数为: 李光:2票; 小北:1票,张良:1票. Check一下,一共有四个人投票了,被投的票是2(李光)+1(小北)+1(张良)=4,总结被投票数没有问题。 因此李光的票最多", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["all"], "content": "李光:2票; 小北:1票,张良:1票 .因此李光的票最多.【重要通知】本轮死亡玩家是李光",}]}\n\n##example##\n{"thought": "在上一轮中, 存活玩家有 小北,人类玩家,赵鹤,张良 四个人。 其中 小北投了人类玩家, 赵鹤投了小北, 张良投了小北, 人类玩家投了张良。总结被投票数为:小北:2票,人类玩家:1票,张良:0票 .Check一下,一共有四个人投票了,被投的票是2(小北)+1(人类玩家)+张良(0)=3,总结被投票数有问题。 更正总结被投票数为:小北:2票,人类玩家:1票,张良:1票。因此小北的票最多", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["all"], "content": "小北:2票,人类玩家:1票,张良:1票 .因此小北的票最多.【重要通知】本轮死亡玩家是小北",}]}\n\n\n##输出##\n请以列表的形式,给出参与者的所有行动。每个行动表示为JSON,格式为\n["thought": str, {"action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}, ...]\n关键词含义如下:\n_ thought (str): 主持人执行行动的一些思考,包括分析玩家的存活状态,对历史对话信息的理解,对当前任务情况的判断。 \n_ player_name (str): ***的 player_name,若行动方为主持人,为空,否则为玩家的 player_name;\n_ agent_name (str): ***的 agent_name,若为主持人,则 agent_name 为 "主持人",否则为玩家的 agent_name。\n_ content (str): 行动方的具体行为,若为主持人,content 为告知信息;否则,content 为玩家的具体行动。\n_ memory_tag (List[str]): 无论行动方是主持人还是玩家,memory_tag 固定为**所有**信息可见对象的agent_name, 如果信息可见对象为所有玩家,固定为 ["all"]\n##example##\n如果是玩家发言,则用 {"thought": "str", "action": {"agent_name": "agent_name_c", "player_name":"player_name_d"}, "observation": [{ "memory_tag":["agent_name_a","agent_name_b"],"content": "str"}]} 格式表示。content是玩家发出的信息;memory_tag是这条信息可见的对象,需要填写agent名。不要填写 agent_description\n如果agent_name是主持人,则无需输入player_name, 且observation变为 Dungeon_Master。即{"thought": "str", "action": {"agent_name": "主持人", "player_name":""}, "Dungeon_Master": [{ "memory_tag":["agent_name_a","agent_name_b"], "content": "str",}]}\n##注意事项##\n1. 所有玩家的座位、身份、agent_name、存活状态等信息在开头部分已给出。\n2. "observation" or "Dungeon_Master"如何选择?若 agent_name 为"主持人",则为"Dungeon_Master",否则为 "observation"。\n3. 输出列表的最后一个元素一定是{"action": "taskend"}。\n4. 整个list是一个jsonstr,请输出jsonstr,不用输出markdown格式\n5. 结合已有的步骤,每次只输出下一个步骤,即一个 {"thought": str, "action": {"player_name":str, "agent_name":str}, "observation" or "Dungeon_Master": [{"memory_tag":str,"content":str}]}\n6. 
如果是人类玩家发言, 一定要选择类似 人类agent 这样的agent_name', +# 'name': '统计票数', +# 'accesscriteria': '', +# 'extra': '{"pattern": "react", "endcheck": "True", "memory_tag":"all","model_name":"gpt_4","dodisplay":"True"}'}}, +# {'id': '剧本杀/谁是卧底/智能交互/关键信息_3', +# 'type': 'opsgptkg_task', +# 'attributes': {'accesscriteria': '', +# 'extra': '{"ignorememory":"True","dodisplay":"True"}', +# 'executetype': '', +# 'description': '关键信息', +# 'name': '关键信息'}}, +# {'id': '剧本杀/谁是卧底/智能交互/判断游戏是否结束', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '判断游戏是否结束', +# 'name': '判断游戏是否结束', +# 'accesscriteria': '', +# 'extra': '{"memory_tag": "None","dodisplay":"True"}', +# 'executetype': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/事实_1', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '是', 'name': '是', 'extra': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/事实_2', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '否', 'name': '否', 'extra': ''}}, +# {'id': '剧本杀/谁是卧底/智能交互/给出每个人的单词以及最终胜利者', +# 'type': 'opsgptkg_task', +# 'attributes': {'extra': '{"dodisplay":"True"}', +# 'executetype': '', +# 'description': '给出每个人的单词以及最终胜利者', +# 'name': '给出每个人的单词以及最终胜利者', +# 'accesscriteria': ''}}, +# {'id': '剧本杀/狼人杀/智能交互/判断游戏是否结束', +# 'type': 'opsgptkg_task', +# 'attributes': {'description': '判断游戏是否结束 ', +# 'name': '判断游戏是否结束 ', +# 'accesscriteria': '', +# 'extra': '{"memory_tag": "None"}', +# 'executetype': ''}}, +# {'id': '剧本杀/狼人杀/智能交互/事实_2', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'extra': '', 'description': '否', 'name': '否'}}, +# {'id': '剧本杀/狼人杀/智能交互/事实_1', +# 'type': 'opsgptkg_phenomenon', +# 'attributes': {'description': '是', 'name': '是', 'extra': ''}}, +# {'id': '剧本杀/l狼人杀/智能交互/宣布游戏胜利者', +# 'type': 'opsgptkg_task', +# 'attributes': {'extra': '', +# 'executetype': '', +# 'description': '判断游戏是否结束', +# 'name': '判断游戏是否结束', +# 'accesscriteria': ''}}, +# {'id': '剧本杀', +# 'type': 'opsgptkg_intent', +# 'attributes': {'description': '文本游戏相关(如狼人杀等)', 'name': '剧本杀', 'extra': ''}}] + +# edges = [('剧本杀', '剧本杀/谁是卧底'), +# ('剧本杀', '剧本杀/狼人杀'), +# ('剧本杀/谁是卧底', '剧本杀/谁是卧底/智能交互'), +# ('剧本杀/狼人杀', '剧本杀/狼人杀/智能交互'), +# ('剧本杀/谁是卧底/智能交互', '剧本杀/谁是卧底/智能交互/分配座位'), +# ('剧本杀/狼人杀/智能交互', '剧本杀/狼人杀/智能交互/位置选择'), +# ('剧本杀/谁是卧底/智能交互/分配座位', '剧本杀/谁是卧底/智能交互/角色分配和单词分配'), +# ('剧本杀/狼人杀/智能交互/位置选择', '剧本杀/狼人杀/智能交互/角色选择'), +# ('剧本杀/谁是卧底/智能交互/角色分配和单词分配', '剧本杀/谁是卧底/智能交互/通知身份'), +# ('剧本杀/狼人杀/智能交互/角色选择', '剧本杀/狼人杀/智能交互/向玩家通知消息'), +# ('剧本杀/谁是卧底/智能交互/通知身份', '剧本杀/谁是卧底/智能交互/关键信息_1'), +# ('剧本杀/狼人杀/智能交互/向玩家通知消息', '剧本杀/狼人杀/智能交互/狼人时刻'), +# ('剧本杀/谁是卧底/智能交互/关键信息_1', '剧本杀/谁是卧底/智能交互/开始新一轮的讨论'), +# ('剧本杀/狼人杀/智能交互/狼人时刻', '剧本杀/狼人杀/智能交互/天亮讨论'), +# ('剧本杀/谁是卧底/智能交互/开始新一轮的讨论', '剧本杀/谁是卧底/智能交互/关键信息_2'), +# ('剧本杀/狼人杀/智能交互/天亮讨论', '剧本杀/狼人杀/智能交互/票选凶手'), +# ('剧本杀/谁是卧底/智能交互/关键信息_2', '剧本杀/谁是卧底/智能交互/票选卧底_1'), +# ('剧本杀/谁是卧底/智能交互/票选卧底_1', '剧本杀/谁是卧底/智能交互/关键信息_4'), +# ('剧本杀/谁是卧底/智能交互/关键信息_4', '剧本杀/谁是卧底/智能交互/统计票数'), +# ('剧本杀/谁是卧底/智能交互/统计票数', '剧本杀/谁是卧底/智能交互/关键信息_3'), +# ('剧本杀/谁是卧底/智能交互/关键信息_3', '剧本杀/谁是卧底/智能交互/判断游戏是否结束'), +# ('剧本杀/谁是卧底/智能交互/判断游戏是否结束', '剧本杀/谁是卧底/智能交互/事实_1'), +# ('剧本杀/谁是卧底/智能交互/判断游戏是否结束', '剧本杀/谁是卧底/智能交互/事实_2'), +# ('剧本杀/谁是卧底/智能交互/事实_1', '剧本杀/谁是卧底/智能交互/给出每个人的单词以及最终胜利者'), +# ('剧本杀/谁是卧底/智能交互/事实_2', '剧本杀/谁是卧底/智能交互/开始新一轮的讨论'), +# ('剧本杀/狼人杀/智能交互/票选凶手', '剧本杀/狼人杀/智能交互/判断游戏是否结束'), +# ('剧本杀/狼人杀/智能交互/判断游戏是否结束', '剧本杀/狼人杀/智能交互/事实_2'), +# ('剧本杀/狼人杀/智能交互/判断游戏是否结束', '剧本杀/狼人杀/智能交互/事实_1'), +# ('剧本杀/狼人杀/智能交互/事实_2', '剧本杀/狼人杀/智能交互/狼人时刻'), +# ('剧本杀/狼人杀/智能交互/事实_1', '剧本杀/l狼人杀/智能交互/宣布游戏胜利者'), +# ('ekg_team_default', '剧本杀') +# ] + + + +tools = [ + "谁是卧底-座位分配", "谁是卧底-角色分配", "谁是卧底-结果输出", "谁是卧底-胜利条件判断", + "谁是卧底-张伟", 
"谁是卧底-李静", "谁是卧底-王鹏", +] + +# tools = [ +# "狼人杀-角色分配工具", "狼人杀-座位分配", "狼人杀-胜利条件判断", "狼人杀-结果输出", +# '狼人杀-agent_朱丽', '狼人杀-agent_周杰', '狼人杀-agent_沈强', '狼人杀-agent_韩刚', +# '狼人杀-agent_梁军', '狼人杀-agent_周欣怡', '狼人杀-agent_贺子轩' +# ] + +AGENT_CONFIGS = { + "codefuse_function_caller": { + "config_name": "codefuse_function_caller", + "agent_type": "FunctioncallAgent", + "agent_name": "codefuse_function_caller", + "llm_config_name": "qwen_chat", + "tools": tools, + } +} +os.environ["AGENT_CONFIGS"] = json.dumps(AGENT_CONFIGS) + +project_config = get_ekg_project_config_from_env() +ekg = EKG(project_config=project_config, initialize_space=False) + + +# # 添加节点 +# for node in nodes: +# ekg.add_node(node) + +# # 添加边 +# for start_id, end_id in edges: +# ekg.add_edge(start_id, end_id) + +response = ekg.run("我要玩谁是卧底!",rootid="ekg_team_default") +for i in response: + pass diff --git a/tests/test_config.py.example b/tests/test_config.py.example index c1347ee..ac03016 100644 --- a/tests/test_config.py.example +++ b/tests/test_config.py.example @@ -1,6 +1,8 @@ import os, openai, base64 from loguru import logger +os.environ["DM_llm_name"] = 'Qwen2_72B_Instruct_OpsGPT' #or gpt_4 + # 兜底大模型配置 OPENAI_API_BASE = "https://api.openai.com/v1" os.environ["API_BASE_URL"] = OPENAI_API_BASE @@ -19,6 +21,78 @@ os.environ["gpt4-llm_temperature"] = "0.0" +MODEL_CONFIGS = { + # old llm config + "default": { + "model_name": "gpt-3.5-turbo", + "model_engine": "qwen", + "temperature": "0", + "api_key": "", + "api_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + }, + "codefuser":{ + "model_name": "gpt-4", + "model_engine": "openai", + "temperature": "0", + "api_key": "", + "api_base_url": OPENAI_API_BASE, + }, + # new llm config + "dashscope_chat": { + "model_type": "dashscope_chat", + "model_name": "qwen2.5-72b-instruct" , + "api_key": "", + }, + "moonshot_chat": { + "model_type": "moonshot_chat", + "model_name": "moonshot-v1-8k" , + "api_key": "", + }, + "ollama_chat": { + "model_type": "ollama_chat", + "model_name": "qwen2.5-0.5b", + "api_key": "", + }, + "openai_chat": { + "model_type": "openai_chat", + "model_name": "gpt-4", + "api_key": "", + }, + "qwen_chat": { + "model_type": "qwen_chat", + "model_name": "qwen2.5-72b-instruct", + "api_key": "", + }, + "yi_chat": { + "model_type": "yi_chat", + "model_name": "yi-lightning" , + "api_key": "", + }, + # embedding configs + "dashscope_text_embedding": { + "model_type": "dashscope_text_embedding", + "model_name": "text-embedding-v3", + "api_key": "", + }, + "ollama_embedding": { + "model_type": "ollama_embedding", + "model_name": "qwen2.5-0.5b", + "api_key": "", + }, + "openai_embedding": { + "model_type": "openai_embedding", + "model_name": "text-embedding-ada-002", + "api_key": "", + }, + "qwen_text_embedding": { + "model_type": "dashscope_text_embedding", + "model_name": "text-embedding-v3", + "api_key": "", + }, +} + +os.environ["MODEL_CONFIGS"] = json.dumps(MODEL_CONFIGS) + #### NebulaHandler #### os.environ['nb_host'] = 'graphd' os.environ['nb_port'] = '9669' @@ -42,6 +116,36 @@ os.environ['tb_definition_value'] = 'message_test_new' os.environ['tb_expire_time'] = '604800' #86400*7 +################# +## DB_CONFIGS ## +################# +DB_CONFIGS = { + "gb_config": { + "gb_type": "NebulaHandler", + "extra_kwargs": { + 'host':'graphd', + 'port': '9669', + 'username': os.environ['nb_username'], + 'password': os.environ['nb_password'], + 'space': "client" + } + }, + "tb_config": { + "tb_type": 'TBaseHandler', + "index_name": "opsgptkg", + "host": 'redis-stack', + 
"port": '6379', + "username": os.environ['tb_username'], + "password": os.environ['tb_password'], + "extra_kwargs": { + "definition_value": "opsgptkg", + "memory_definition_value": "opsgptkg_message" + } + } +} +os.environ["DB_CONFIGS"] = json.dumps(DB_CONFIGS) + + ######################################## ########## 以下参数暂不涉及无需配置 ######## diff --git a/tests/tools/get_tool.py b/tests/tools/get_tool.py new file mode 100644 index 0000000..b9155e9 --- /dev/null +++ b/tests/tools/get_tool.py @@ -0,0 +1,30 @@ +import os +from loguru import logger + +try: + import os, sys + src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + sys.path.append(src_dir) + import test_config +except Exception as e: + # set your config + logger.error(f"{e}") + + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from muagent import get_tool +from muagent.tools import toLangchainTools + + +tools = toLangchainTools([get_tool("Multiplier")]) + +print(get_tool("Multiplier").intput_to_json_schema()) +print(get_tool("Multiplier").output_to_json_schema()) +# tool run 测试 +print(tools[0].func(1,2)) \ No newline at end of file