From cbec7abdddc6c9e3be25f588d29c8a87c5fecff7 Mon Sep 17 00:00:00 2001 From: jinno Date: Sat, 6 Jul 2024 09:06:12 +0900 Subject: [PATCH] fix: update agents, supervisors, graphs for coding --- src/codeinterpreterapi/agents/agents.py | 4 +-- .../agents/config/coding_agent/config.yaml | 7 ++++- .../agents/structured_chat/agent_executor.py | 6 ++-- .../agents/tool_calling/agent_executor.py | 10 ++++--- src/codeinterpreterapi/brain/params.py | 8 +++-- src/codeinterpreterapi/graphs/graphs.py | 29 ++++++++++--------- src/codeinterpreterapi/llm/llm.py | 11 +++++-- .../supervisors/supervisors.py | 2 +- .../test_prompts/test_prompt.py | 2 ++ 9 files changed, 51 insertions(+), 28 deletions(-) diff --git a/src/codeinterpreterapi/agents/agents.py b/src/codeinterpreterapi/agents/agents.py index f94b3b41..cb72fdc1 100644 --- a/src/codeinterpreterapi/agents/agents.py +++ b/src/codeinterpreterapi/agents/agents.py @@ -16,6 +16,7 @@ from codeinterpreterapi.brain.params import CodeInterpreterParams from codeinterpreterapi.config import settings from codeinterpreterapi.llm.llm import prepare_test_llm +from codeinterpreterapi.test_prompts.test_prompt import TestPrompt from codeinterpreterapi.tools.tools import CodeInterpreterTools @@ -119,7 +120,6 @@ def create_single_chat_agent_executor(ci_params: CodeInterpreterParams) -> Agent def test(): - sample = "ツールのpythonで円周率を表示するプログラムを実行してください。" # sample = "lsコマンドを実行してください。" llm, llm_tools = prepare_test_llm() ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) @@ -129,7 +129,7 @@ def test(): agent_executors = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) # agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params=ci_params) # agent = CodeInterpreterAgent.create_agent_and_executor_experimental(ci_params=ci_params) - result = agent_executors[0].invoke({"input": sample}) + result = agent_executors[0].invoke({"input": TestPrompt.svg_input_str}) print("result=", result) diff --git a/src/codeinterpreterapi/agents/config/coding_agent/config.yaml b/src/codeinterpreterapi/agents/config/coding_agent/config.yaml index 2c6fb0a6..c653757a 100644 --- a/src/codeinterpreterapi/agents/config/coding_agent/config.yaml +++ b/src/codeinterpreterapi/agents/config/coding_agent/config.yaml @@ -1,9 +1,14 @@ agent_definition: agent_name: "main関数を作成するエージェント" - agent_type: structured_chat + agent_type: tool_calling agent_role: | プログラムのmain関数を作成する作業をお願いします。 別のagentはプログラムの分割やテストの作成を実施することになります。 あなたの役割を理解して適切な回答をしてください。 + あなたの役割は次の通りです。 + ・コードの作成 + ・コードの実行(toolを使って実行してエラーなく実行できるところまで確認してください) + ・エラーがある場合の修正と再実行 + 次に続くシステムプロンプトを注意深く読んで正しくふるまってください。 diff --git a/src/codeinterpreterapi/agents/structured_chat/agent_executor.py b/src/codeinterpreterapi/agents/structured_chat/agent_executor.py index ff8dcaa3..88cfdd9a 100644 --- a/src/codeinterpreterapi/agents/structured_chat/agent_executor.py +++ b/src/codeinterpreterapi/agents/structured_chat/agent_executor.py @@ -18,11 +18,11 @@ def load_structured_chat_agent_executor( """ prompt = create_structured_chat_agent_prompt(ci_params.is_ja) if agent_def.agent_role is not None: - print("load_structured_chat_agent_executor prompt.partial agent_def.message_prompt_template") prompt = prompt.partial(agent_role=agent_def.agent_role) input_variables = prompt.input_variables - print("load_structured_chat_agent_executor prompt.input_variables=", input_variables) - print("load_structured_chat_agent_executor prompt=", prompt.messages) + if ci_params.verbose_prompt: + print("load_structured_chat_agent_executor prompt.input_variables=", input_variables) + print("load_structured_chat_agent_executor prompt=", prompt.messages) agent = create_structured_chat_agent( llm=ci_params.llm_tools, tools=ci_params.tools, diff --git a/src/codeinterpreterapi/agents/tool_calling/agent_executor.py b/src/codeinterpreterapi/agents/tool_calling/agent_executor.py index 0aa77849..516a79a0 100644 --- a/src/codeinterpreterapi/agents/tool_calling/agent_executor.py +++ b/src/codeinterpreterapi/agents/tool_calling/agent_executor.py @@ -18,11 +18,11 @@ def load_tool_calling_agent_executor( """ prompt = create_tool_calling_agent_prompt(ci_params.is_ja) if agent_def.message_prompt_template is not None: - print("load_tool_calling_agent_executor prompt.partial agent_def.message_prompt_template") prompt = prompt.partial(agent_role=agent_def.agent_role) - input_variables = prompt.input_variables - print("load_tool_calling_agent_executor prompt.input_variables=", input_variables) - print("load_tool_calling_agent_executor prompt=", prompt.messages) + if ci_params.verbose_prompt: + input_variables = prompt.input_variables + print("load_tool_calling_agent_executor prompt.input_variables=", input_variables) + print("load_tool_calling_agent_executor prompt=", prompt.messages) agent = create_tool_calling_agent( llm=ci_params.llm, tools=ci_params.tools, @@ -41,6 +41,8 @@ def load_tool_calling_agent_executor( def test(): llm, llm_tools = prepare_test_llm() ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + ci_params.verbose = True + ci_params.verbose_prompt = False agent_executor = load_tool_calling_agent_executor(ci_params) test_input = "pythonで円周率を表示するプログラムを実行してください。" agent_executor_output = agent_executor.invoke({"input": test_input}) diff --git a/src/codeinterpreterapi/brain/params.py b/src/codeinterpreterapi/brain/params.py index 0fbb3a13..e046fa6c 100644 --- a/src/codeinterpreterapi/brain/params.py +++ b/src/codeinterpreterapi/brain/params.py @@ -36,6 +36,7 @@ class CodeInterpreterParams(BaseModel): tools: Optional[List[BaseTool]] = [] callbacks: Optional[Callbacks] = None verbose: Optional[bool] = False + verbose_prompt: Optional[bool] = False is_local: Optional[bool] = True is_ja: Optional[bool] = True runnable_config: Optional[RunnableConfig] = None @@ -43,10 +44,13 @@ class CodeInterpreterParams(BaseModel): supervisor_agent: Optional[Runnable] = None @classmethod - def get_test_params(cls, llm: BaseLanguageModel, llm_tools: BaseChatModel = None): + def get_test_params( + cls, llm: BaseLanguageModel, llm_tools: BaseChatModel = None, runnable_config: RunnableConfig = None + ): tools = [test_plus, test_multiply] configurable = {"session_id": "123"} - runnable_config = RunnableConfig(configurable=configurable) + if RunnableConfig is None: + runnable_config = RunnableConfig(configurable=configurable) return CodeInterpreterParams( llm_lite=llm, llm_fast=llm, diff --git a/src/codeinterpreterapi/graphs/graphs.py b/src/codeinterpreterapi/graphs/graphs.py index 533cb317..cb612127 100644 --- a/src/codeinterpreterapi/graphs/graphs.py +++ b/src/codeinterpreterapi/graphs/graphs.py @@ -2,9 +2,7 @@ import operator from typing import Annotated, Sequence, TypedDict -from langchain.callbacks import StdOutCallbackHandler from langchain_core.messages import BaseMessage -from langchain_core.runnables import RunnableConfig from langgraph.graph import END, StateGraph from codeinterpreterapi.agents.agents import CodeInterpreterAgent @@ -17,11 +15,14 @@ def agent_node(state, agent, name): print(f"agent_node {name} node!") - print(" state=", state) - if "input" not in state: - state["input"] = state["question"] - result = agent.invoke(state) - print("agent_node result=", result) + print(" state keys=", state.keys()) + inputs = state + if "input" not in inputs: + # inputs["input"] = state["question"] + inputs["input"] = str(state["messages"]) + # inputs["agent_scratchpad"] = str(state["messages"]) + result = agent.invoke(inputs) + print("agent_node type(result)=", type(result)) if "output" in result: state["messages"].append(str(result["output"])) return state @@ -29,23 +30,27 @@ def agent_node(state, agent, name): def supervisor_node(state, supervisor, name): print(f"supervisor_node {name} node!") - print(" state=", state) + print(" state keys=", state.keys()) result = supervisor.invoke(state) print("supervisor_node type(result)=", type(result)) - print("supervisor_node result=", result) + # print("supervisor_node result=", result) + state["question"] = state["messages"][0] if result is None: state["next"] = "FINISH" elif isinstance(result, dict): + print("supervisor_node type(result)=", type(result)) # if "output" in result: # state["messages"].append(str(result["output"])) if "next" in result: - state["next"] = result.next + state["next"] = result["next"] + print("supervisor_node result(dict) next=", result["next"]) state["messages"].append(f"次のagentは「{result.next}」です。") elif hasattr(result, "next"): # RouteSchema object state["next"] = result.next state["messages"].append(f"次のagentは「{result.next}」です。") + print("supervisor_node result(RouteSchema) next=", result.next) else: state["next"] = "FINISH" @@ -119,14 +124,12 @@ def run(self, input_data): def test(): llm, llm_tools = prepare_test_llm() - config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]}) - llm = llm.with_config(config) ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) _ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params) _ = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params) - sg = CodeInterpreterStateGraph(ci_params) + sg = CodeInterpreterStateGraph(ci_params=ci_params) output = sg.run({"input": TestPrompt.svg_input_str, "messages": [TestPrompt.svg_input_str]}) print("output=", output) diff --git a/src/codeinterpreterapi/llm/llm.py b/src/codeinterpreterapi/llm/llm.py index 24c25fd5..5ae8a33b 100644 --- a/src/codeinterpreterapi/llm/llm.py +++ b/src/codeinterpreterapi/llm/llm.py @@ -2,9 +2,10 @@ from google.ai.generativelanguage_v1beta.types import GenerateContentRequest from google.generativeai.types.content_types import FunctionDeclarationType # type: ignore[import] +from langchain.callbacks import StdOutCallbackHandler from langchain.chat_models.base import BaseChatModel from langchain_core.messages import BaseMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore from langchain_google_genai._common import SafetySettingDict from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDictLike @@ -148,4 +149,10 @@ def get_llms(cls, model: str = settings.MODEL_LOCAL) -> List[BaseChatModel]: def prepare_test_llm(): - return CodeInterpreterLlm.get_llm_switcher(), CodeInterpreterLlm.get_llm_switcher_tools() + llm = CodeInterpreterLlm.get_llm_switcher() + llm_tools = CodeInterpreterLlm.get_llm_switcher_tools() + runnable_config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]}) + llm = llm.with_config(runnable_config) + llm_tools = llm_tools.with_config(runnable_config) + + return llm, llm_tools diff --git a/src/codeinterpreterapi/supervisors/supervisors.py b/src/codeinterpreterapi/supervisors/supervisors.py index 36d0e7c0..829916ec 100644 --- a/src/codeinterpreterapi/supervisors/supervisors.py +++ b/src/codeinterpreterapi/supervisors/supervisors.py @@ -68,7 +68,7 @@ def choose_supervisor(planner: Runnable, ci_params: CodeInterpreterParams) -> Ag class RouteSchema(BaseModel): next: str = Field(..., description=f"The next route item. This is one of: {options}") - question: str = Field(..., description="The original question from user.") + # question: str = Field(..., description="The original question from user.") class CustomOutputParserForGraph(AgentOutputParser): def parse(self, text: str) -> dict: diff --git a/src/codeinterpreterapi/test_prompts/test_prompt.py b/src/codeinterpreterapi/test_prompts/test_prompt.py index 0235601f..b9e1a3ef 100644 --- a/src/codeinterpreterapi/test_prompts/test_prompt.py +++ b/src/codeinterpreterapi/test_prompts/test_prompt.py @@ -30,3 +30,5 @@ class TestPrompt: Python3とlxmlライブラリを使用して実装すること コードはモジュール化し、再利用性と保守性を高めること """ + + ls_command_str = "lsコマンドを実行してください。"