fix: add graphs.py for langgraph
nobu007 committed Jun 30, 2024
1 parent 9a4612f commit 423f6aa
Showing 6 changed files with 210 additions and 72 deletions.
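
The new graphs.py wires the existing agents into a LangGraph supervisor/worker loop: a supervisor node writes the name of the next role into the shared graph state, and conditional edges route either to that worker node or to END. The sketch below is a minimal, self-contained illustration of that pattern using only the langgraph dependency added in this commit; the node names, state fields, and routing logic are illustrative stand-ins, not the repository's own classes.

import operator
from typing import Annotated, Sequence, TypedDict

from langgraph.graph import END, StateGraph


class State(TypedDict):
    messages: Annotated[Sequence[str], operator.add]  # each node's update is appended, not overwritten
    next: str  # the node the supervisor wants to run next


def supervisor(state: State) -> dict:
    # A real supervisor asks an LLM to pick the next role; this stand-in
    # finishes as soon as the worker has reported back.
    done = any("worker" in m for m in state["messages"])
    return {"next": "FINISH" if done else "worker", "messages": ["supervisor: routed"]}


def worker(state: State) -> dict:
    return {"messages": ["worker: ran"]}


workflow = StateGraph(State)
workflow.add_node("supervisor", supervisor)
workflow.add_node("worker", worker)
workflow.add_edge("worker", "supervisor")  # workers always report back to the supervisor
workflow.add_conditional_edges("supervisor", lambda s: s["next"], {"worker": "worker", "FINISH": END})
workflow.set_entry_point("supervisor")
app = workflow.compile()

print(app.invoke({"messages": [], "next": ""}))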
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ langchain_anthropic
langchain_community
langchain_experimental
langchainhub
langgraph
langsmith
https://huggingface.co/spacy/en_core_web_md/resolve/main/en_core_web_md-any-py3-none-any.whl
lxml
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/agents.py
@@ -126,10 +126,10 @@ def test():
ci_params.tools = []
ci_params.tools = CodeInterpreterTools(ci_params).get_all_tools()

agent = CodeInterpreterAgent.choose_agent_executor(ci_params=ci_params)
agent_executors = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
# agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params=ci_params)
# agent = CodeInterpreterAgent.create_agent_and_executor_experimental(ci_params=ci_params)
result = agent.invoke({"input": sample})
result = agent_executors[0].invoke({"input": sample})
print("result=", result)


4 changes: 2 additions & 2 deletions src/codeinterpreterapi/brain/brain.py
@@ -181,7 +181,7 @@ def test():
ci_params.tools = CodeInterpreterTools(ci_params).get_all_tools()
brain = CodeInterpreterBrain(ci_params)

if True:
if False:
# try1: agent_executor
print("try1: agent_executor")
brain.use_agent(AgentName.AGENT_EXECUTOR)
@@ -205,7 +205,7 @@ def test():
assert "python" == result.tool
assert "test output" in result.tool_input

if False:
if True:
# try3: supervisor
print("try3: supervisor")
sample = "ステップバイステップで2*5+2を計算して。"
77 changes: 66 additions & 11 deletions src/codeinterpreterapi/graphs/graphs.py
@@ -1,28 +1,75 @@
import functools
import operator
from typing import Annotated, Sequence, TypedDict

from langchain.callbacks import StdOutCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner
from codeinterpreterapi.supervisors.supervisors import CodeInterpreterSupervisor
from codeinterpreterapi.test_prompts.test_prompt import TestPrompt


def agent_node(state, agent, name):
print(f"agent_node {name} node!")
print(" state=", state)
if "input" not in state:
state["input"] = state["question"]
result = agent.invoke(state)
print("agent_node result=", result)
if "output" in result:
state["messages"].append(str(result["output"]))
return state


def supervisor_node(state, supervisor, name):
print(f"supervisor_node {name} node!")
print(" state=", state)
result = supervisor.invoke(state)
print("supervisor_node type(result)=", type(result))
print("supervisor_node result=", result)

if result is None:
state["next"] = "FINISH"
elif isinstance(result, dict):
# if "output" in result:
# state["messages"].append(str(result["output"]))
if "next" in result:
state["next"] = result.next
state["messages"].append(f"次のagentは「{result.next}」です。")
elif hasattr(result, "next"):
# RouteSchema object
state["next"] = result.next
state["messages"].append(f"次のagentは「{result.next}」です。")
else:
state["next"] = "FINISH"

if state["next"] == "FINISH":
state["messages"].append("処理完了です。")
else:
next_agent = state["next"]
state["messages"].append(f"次のagentは「{next_agent}」です。")

return state


class CodeInterpreterStateGraph:
def __init__(self, ci_params: CodeInterpreterParams):
self.ci_params = ci_params
self.initialize_agent_info()
self.graph = self.initialize_graph()
self.app = self.graph.compile()
self.node_descriptions_dict = {}
self.node_agent_dict = {}
self.initialize_agent_info()
self.graph = self.initialize_graph()

def initialize_agent_info(self) -> None:
# Create a node for each agent
for agent_def in self.ci_params.agent_def_list:
agent = agent_def.agent
agent = agent_def.agent_executor
agent_name = agent_def.agent_name
agent_role = agent_def.agent_role

@@ -35,21 +82,27 @@ class GraphState(TypedDict):
# emb_model: HuggingFaceEmbeddings # embeddings model
question: str # the question text
# documents: List[Document] # list of documents retrieved from the index
messages: Annotated[Sequence[BaseMessage], operator.add] # message history
messages: Annotated[Sequence[BaseMessage], operator.add] = []
# intermediate_steps: str = ""

def initialize_graph(self) -> StateGraph:
workflow = StateGraph(CodeInterpreterStateGraph.GraphState)

SUPERVISOR_AGENT_NAME = "supervisor_agent"
workflow.add_node(SUPERVISOR_AGENT_NAME, self.ci_params.supervisor_agent)
supervisor_node_replaced = functools.partial(
supervisor_node, supervisor=self.ci_params.supervisor_agent, name=SUPERVISOR_AGENT_NAME
)
workflow.add_node(SUPERVISOR_AGENT_NAME, supervisor_node_replaced)
for agent_name, agent in self.node_agent_dict.items():
workflow.add_node(agent_name, agent)
agent_node_replaced = functools.partial(agent_node, agent=agent, name=agent_name)
workflow.add_node(agent_name, agent_node_replaced)
workflow.add_edge(agent_name, SUPERVISOR_AGENT_NAME)

# The supervisor populates the "next" field in the graph state
# which routes to a node or finishes
conditional_map = {k: k for k, _ in self.node_descriptions_dict}
conditional_map = {k: k for k, _ in self.node_descriptions_dict.items()}
conditional_map["FINISH"] = END
print("conditional_map=", conditional_map)
workflow.add_conditional_edges(SUPERVISOR_AGENT_NAME, lambda x: x["next"], conditional_map)
# Finally, add entrypoint
workflow.set_entry_point(SUPERVISOR_AGENT_NAME)
@@ -60,19 +113,21 @@ def initialize_graph(self) -> StateGraph:

def run(self, input_data):
# Run the graph
final_state = self.app.invoke(input_data)
final_state = self.graph.invoke(input_data)
return final_state


def test():
llm, llm_tools = prepare_test_llm()
config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]})
llm = llm.with_config(config)
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params)
_ = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params)

sg = CodeInterpreterStateGraph(ci_params)
test_input = "pythonで円周率を表示するプログラムを実行してください。"
output = sg.invoke({"messages": [test_input]})
output = sg.run({"input": TestPrompt.svg_input_str, "messages": [TestPrompt.svg_input_str]})
print("output=", output)
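
supervisor_node above has to normalize two result shapes: a plain dict and a pydantic object such as the RouteSchema produced by the structured-output supervisor in supervisors.py below. A small self-contained illustration of that normalization, using a stand-in schema rather than the repository's own class:

from pydantic import BaseModel


class RouteSchema(BaseModel):  # stand-in for the RouteSchema defined in supervisors.py
    next: str
    question: str = ""


def extract_next(result) -> str:
    # Normalize a supervisor result into the name of the next node, or FINISH.
    if result is None:
        return "FINISH"
    if isinstance(result, dict) and "next" in result:
        return result["next"]
    if hasattr(result, "next"):  # structured-output object
        return result.next
    return "FINISH"


print(extract_next({"next": "agent_executor"}))                   # agent_executor
print(extract_next(RouteSchema(next="FINISH", question="done")))  # FINISH
print(extract_next(None))                                         # FINISH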


164 changes: 107 additions & 57 deletions src/codeinterpreterapi/supervisors/supervisors.py
@@ -1,16 +1,21 @@
import getpass
import os
import platform
from typing import Union
from typing import Any, List, Union

from langchain.agents import AgentExecutor, AgentOutputParser, Tool
from langchain.agents import AgentExecutor, AgentOutputParser
from langchain.schema import AgentAction, AgentFinish
from langchain_core.output_parsers.openai_functions import PydanticOutputFunctionsParser
from langchain_core.outputs import Generation
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner
from codeinterpreterapi.supervisors.prompts import create_supervisor_agent_prompt
from codeinterpreterapi.test_prompts.test_prompt import TestPrompt


class CodeInterpreterSupervisor:
@@ -29,6 +34,7 @@ def choose_supervisor(planner: Runnable, ci_params: CodeInterpreterParams) -> Ag
members.append(agent_def.agent_name)

options = ["FINISH"] + members
print("options=", options)

# prompt
prompt = create_supervisor_agent_prompt(ci_params.is_ja)
@@ -37,81 +43,125 @@ def choose_supervisor(planner: Runnable, ci_params: CodeInterpreterParams) -> Ag
print("choose_supervisor prompt.input_variables=", input_variables)

# Using openai function calling can make output parsing easier for us
function_def = {
"name": "route",
"description": "Select the next role.",
"parameters": {
"title": "routeSchema",
"type": "object",
"properties": {
"next": {
"title": "Next",
"anyOf": [
{"enum": options},
],
}
},
"required": ["next"],
},
}

class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
print("parse llm_output=", llm_output)
if "Route" in llm_output:
next_action = llm_output.split("Route:")[-1].strip()
if next_action in options:
return AgentAction(tool="Route", tool_input=next_action, log=llm_output)
else:
return AgentFinish(
return_values={
"result": f"Invalid next action. Available options are: {', '.join(options)}"
},
log=llm_output,
)
# function_def = {
# "name": "route",
# "description": "Select the next role.",
# "parameters": {
# "title": "routeSchema",
# "type": "object",
# "properties": {
# "next": {
# "title": "Next",
# "anyOf": [
# {"enum": options},
# ],
# },
# "question": {
# "title": "Question",
# "type": "string",
# "description": "もともとの質問内容",
# },
# },
# "required": ["next", "question"],
# },
# }

class RouteSchema(BaseModel):
next: str = Field(..., description=f"The next route item. This is one of: {options}")
question: str = Field(..., description="The original question from user.")

class CustomOutputParserForGraph(AgentOutputParser):
def parse(self, text: str) -> dict:
print("CustomOutputParserForGraph parse text=", text)
if isinstance(text, str):
next_route = text
next_route = next_route.strip()
next_route = next_route.replace("'", "")
next_route = next_route.replace('"', "")
else:
return AgentFinish(
return_values={
"result": f"Agent did not select the Route tool. Available options are: {', '.join(options)}"
},
log=llm_output,
)

output_parser = CustomOutputParser()
next_route = "FINISH"
return_values = {
"next": next_route,
"question": "",
"messages": [],
"intermediate_steps": [],
}
return return_values

class CustomOutputParserForExecutor(CustomOutputParserForGraph):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
print("CustomOutputParserForExecutor parse text=", text)
return_values = super().parse(text)
return AgentFinish(
return_values=return_values,
log=text,
)

class CustomOutputParserForGraphPydantic(PydanticOutputFunctionsParser):
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
print("CustomOutputParserForGraphPydantic parse_result result=", result)
result = super().parse_result(result, partial=partial)
print("CustomOutputParserForGraphPydantic parse_result result after=", result)
if isinstance(result, str):
next_route = result
next_route = next_route.strip()
next_route = next_route.replace("'", "")
next_route = next_route.replace('"', "")
else:
next_route = "FINISH"
return_values = {
"next": next_route,
"question": "",
"messages": [],
"intermediate_steps": [],
}
return return_values

# output_parser = CustomOutputParserForGraph()
# output_parser_for_executor = CustomOutputParserForExecutor()
# output_parser_pydantic = CustomOutputParserForGraphPydantic(pydantic_schema=RouteSchema)

# tool
def route(next_action: str) -> str:
if next_action not in options:
return f"Invalid next action. Available options are: {', '.join(options)}"
return f"The next action is: {next_action}"

tool = Tool(
name="Route",
func=route,
description="Select the next role from the available options.",
schema=function_def["parameters"],
)
# tool = Tool(
# name="Route",
# func=route,
# description="Select the next role from the available options.",
# # schema=function_def["parameters"],
# )

# config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]})

# agent
supervisor_agent = prompt | ci_params.llm | output_parser
ci_params.supervisor_agent = supervisor_agent
llm_with_structured_output = ci_params.llm.with_structured_output(RouteSchema)
# supervisor_agent = prompt | ci_params.llm | output_parser
# supervisor_agent_for_executor = prompt | ci_params.llm
supervisor_agent_structured_output = prompt | llm_with_structured_output

ci_params.supervisor_agent = supervisor_agent_structured_output

# agent_executor
agent_executor = AgentExecutor.from_agent_and_tools(
agent=supervisor_agent, tools=[tool], verbose=ci_params.verbose
)
# agent_executor = AgentExecutor.from_agent_and_tools(
# agent=supervisor_agent_for_executor,
# tools=[tool],
# verbose=ci_params.verbose,
# callbacks=[],
# )

return agent_executor
return supervisor_agent_structured_output


def test():
# sample = "ステップバイステップで2*5+2を計算して。"
sample = "pythonで円周率を表示するプログラムを実行してください。"
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
_ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params)
supervisor = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params)
result = supervisor.invoke({"input": sample, "messages": [sample]})
result = supervisor.invoke({"input": TestPrompt.svg_input_str, "messages": [TestPrompt.svg_input_str]})
print("result=", result)
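
The supervisor is now simply prompt | llm.with_structured_output(RouteSchema), so the routing decision comes back as a validated RouteSchema instance instead of free text that needs a custom parser. A hedged sketch of that technique in isolation, assuming the langchain_anthropic dependency already in requirements.txt; the prompt wording and model name are illustrative, not the repository's own:

from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field


class RouteSchema(BaseModel):
    next: str = Field(..., description="The next route item, e.g. 'FINISH' or an agent name.")
    question: str = Field(..., description="The original question from the user.")


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a supervisor. Choose the next worker to act, or FINISH."),
        ("human", "{input}"),
    ]
)

llm = ChatAnthropic(model="claude-3-haiku-20240307")  # illustrative model choice
supervisor = prompt | llm.with_structured_output(RouteSchema)

result = supervisor.invoke({"input": "Please run a Python program that prints pi."})
print(result.next, result.question)  # e.g. FINISH plus the echoed question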

