Skip to content

Commit

Permalink
fix: update agents, supervisors, graphs for coding
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Jul 6, 2024
1 parent 423f6aa commit cbec7ab
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.config import settings
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.test_prompts.test_prompt import TestPrompt
from codeinterpreterapi.tools.tools import CodeInterpreterTools


Expand Down Expand Up @@ -119,7 +120,6 @@ def create_single_chat_agent_executor(ci_params: CodeInterpreterParams) -> Agent


def test():
sample = "ツールのpythonで円周率を表示するプログラムを実行してください。"
# sample = "lsコマンドを実行してください。"
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
Expand All @@ -129,7 +129,7 @@ def test():
agent_executors = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
# agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params=ci_params)
# agent = CodeInterpreterAgent.create_agent_and_executor_experimental(ci_params=ci_params)
result = agent_executors[0].invoke({"input": sample})
result = agent_executors[0].invoke({"input": TestPrompt.svg_input_str})
print("result=", result)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
agent_definition:
agent_name: "main関数を作成するエージェント"
agent_type: structured_chat
agent_type: tool_calling
agent_role: |
プログラムのmain関数を作成する作業をお願いします。
別のagentはプログラムの分割やテストの作成を実施することになります。
あなたの役割を理解して適切な回答をしてください。
あなたの役割は次の通りです。
・コードの作成
・コードの実行(toolを使って実行してエラーなく実行できるところまで確認してください)
・エラーがある場合の修正と再実行
次に続くシステムプロンプトを注意深く読んで正しくふるまってください。
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def load_structured_chat_agent_executor(
"""
prompt = create_structured_chat_agent_prompt(ci_params.is_ja)
if agent_def.agent_role is not None:
print("load_structured_chat_agent_executor prompt.partial agent_def.message_prompt_template")
prompt = prompt.partial(agent_role=agent_def.agent_role)
input_variables = prompt.input_variables
print("load_structured_chat_agent_executor prompt.input_variables=", input_variables)
print("load_structured_chat_agent_executor prompt=", prompt.messages)
if ci_params.verbose_prompt:
print("load_structured_chat_agent_executor prompt.input_variables=", input_variables)
print("load_structured_chat_agent_executor prompt=", prompt.messages)
agent = create_structured_chat_agent(
llm=ci_params.llm_tools,
tools=ci_params.tools,
Expand Down
10 changes: 6 additions & 4 deletions src/codeinterpreterapi/agents/tool_calling/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def load_tool_calling_agent_executor(
"""
prompt = create_tool_calling_agent_prompt(ci_params.is_ja)
if agent_def.message_prompt_template is not None:
print("load_tool_calling_agent_executor prompt.partial agent_def.message_prompt_template")
prompt = prompt.partial(agent_role=agent_def.agent_role)
input_variables = prompt.input_variables
print("load_tool_calling_agent_executor prompt.input_variables=", input_variables)
print("load_tool_calling_agent_executor prompt=", prompt.messages)
if ci_params.verbose_prompt:
input_variables = prompt.input_variables
print("load_tool_calling_agent_executor prompt.input_variables=", input_variables)
print("load_tool_calling_agent_executor prompt=", prompt.messages)
agent = create_tool_calling_agent(
llm=ci_params.llm,
tools=ci_params.tools,
Expand All @@ -41,6 +41,8 @@ def load_tool_calling_agent_executor(
def test():
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
ci_params.verbose = True
ci_params.verbose_prompt = False
agent_executor = load_tool_calling_agent_executor(ci_params)
test_input = "pythonで円周率を表示するプログラムを実行してください。"
agent_executor_output = agent_executor.invoke({"input": test_input})
Expand Down
8 changes: 6 additions & 2 deletions src/codeinterpreterapi/brain/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ class CodeInterpreterParams(BaseModel):
tools: Optional[List[BaseTool]] = []
callbacks: Optional[Callbacks] = None
verbose: Optional[bool] = False
verbose_prompt: Optional[bool] = False
is_local: Optional[bool] = True
is_ja: Optional[bool] = True
runnable_config: Optional[RunnableConfig] = None
agent_def_list: Optional[List[AgentDefinition]] = []
supervisor_agent: Optional[Runnable] = None

@classmethod
def get_test_params(cls, llm: BaseLanguageModel, llm_tools: BaseChatModel = None):
def get_test_params(
cls, llm: BaseLanguageModel, llm_tools: BaseChatModel = None, runnable_config: RunnableConfig = None
):
tools = [test_plus, test_multiply]
configurable = {"session_id": "123"}
runnable_config = RunnableConfig(configurable=configurable)
if RunnableConfig is None:
runnable_config = RunnableConfig(configurable=configurable)
return CodeInterpreterParams(
llm_lite=llm,
llm_fast=llm,
Expand Down
29 changes: 16 additions & 13 deletions src/codeinterpreterapi/graphs/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import operator
from typing import Annotated, Sequence, TypedDict

from langchain.callbacks import StdOutCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
Expand All @@ -17,35 +15,42 @@

def agent_node(state, agent, name):
print(f"agent_node {name} node!")
print(" state=", state)
if "input" not in state:
state["input"] = state["question"]
result = agent.invoke(state)
print("agent_node result=", result)
print(" state keys=", state.keys())
inputs = state
if "input" not in inputs:
# inputs["input"] = state["question"]
inputs["input"] = str(state["messages"])
# inputs["agent_scratchpad"] = str(state["messages"])
result = agent.invoke(inputs)
print("agent_node type(result)=", type(result))
if "output" in result:
state["messages"].append(str(result["output"]))
return state


def supervisor_node(state, supervisor, name):
print(f"supervisor_node {name} node!")
print(" state=", state)
print(" state keys=", state.keys())
result = supervisor.invoke(state)
print("supervisor_node type(result)=", type(result))
print("supervisor_node result=", result)
# print("supervisor_node result=", result)

state["question"] = state["messages"][0]
if result is None:
state["next"] = "FINISH"
elif isinstance(result, dict):
print("supervisor_node type(result)=", type(result))
# if "output" in result:
# state["messages"].append(str(result["output"]))
if "next" in result:
state["next"] = result.next
state["next"] = result["next"]
print("supervisor_node result(dict) next=", result["next"])
state["messages"].append(f"次のagentは「{result.next}」です。")
elif hasattr(result, "next"):
# RouteSchema object
state["next"] = result.next
state["messages"].append(f"次のagentは「{result.next}」です。")
print("supervisor_node result(RouteSchema) next=", result.next)
else:
state["next"] = "FINISH"

Expand Down Expand Up @@ -119,14 +124,12 @@ def run(self, input_data):

def test():
llm, llm_tools = prepare_test_llm()
config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]})
llm = llm.with_config(config)
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params)
_ = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params)

sg = CodeInterpreterStateGraph(ci_params)
sg = CodeInterpreterStateGraph(ci_params=ci_params)
output = sg.run({"input": TestPrompt.svg_input_str, "messages": [TestPrompt.svg_input_str]})
print("output=", output)

Expand Down
11 changes: 9 additions & 2 deletions src/codeinterpreterapi/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from google.ai.generativelanguage_v1beta.types import GenerateContentRequest
from google.generativeai.types.content_types import FunctionDeclarationType # type: ignore[import]
from langchain.callbacks import StdOutCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore
from langchain_google_genai._common import SafetySettingDict
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDictLike
Expand Down Expand Up @@ -148,4 +149,10 @@ def get_llms(cls, model: str = settings.MODEL_LOCAL) -> List[BaseChatModel]:


def prepare_test_llm():
return CodeInterpreterLlm.get_llm_switcher(), CodeInterpreterLlm.get_llm_switcher_tools()
llm = CodeInterpreterLlm.get_llm_switcher()
llm_tools = CodeInterpreterLlm.get_llm_switcher_tools()
runnable_config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]})
llm = llm.with_config(runnable_config)
llm_tools = llm_tools.with_config(runnable_config)

return llm, llm_tools
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/supervisors/supervisors.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def choose_supervisor(planner: Runnable, ci_params: CodeInterpreterParams) -> Ag

class RouteSchema(BaseModel):
next: str = Field(..., description=f"The next route item. This is one of: {options}")
question: str = Field(..., description="The original question from user.")
# question: str = Field(..., description="The original question from user.")

class CustomOutputParserForGraph(AgentOutputParser):
def parse(self, text: str) -> dict:
Expand Down
2 changes: 2 additions & 0 deletions src/codeinterpreterapi/test_prompts/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ class TestPrompt:
Python3とlxmlライブラリを使用して実装すること
コードはモジュール化し、再利用性と保守性を高めること
"""

ls_command_str = "lsコマンドを実行してください。"

0 comments on commit cbec7ab

Please sign in to comment.