diff --git a/src/codeinterpreterapi/agents/agents.py b/src/codeinterpreterapi/agents/agents.py index cb72fdc1..5e00c6b5 100644 --- a/src/codeinterpreterapi/agents/agents.py +++ b/src/codeinterpreterapi/agents/agents.py @@ -112,7 +112,7 @@ def create_single_chat_agent_executor(ci_params: CodeInterpreterParams) -> Agent # return_messages=True, # chat_memory=chat_memory, # ), - callbacks=ci_params.callbacks, + # callbacks=ci_params.callbacks, ) print("agent_executor.input_keys", agent_executor.input_keys) print("agent_executor.output_keys", agent_executor.output_keys) @@ -120,15 +120,11 @@ def create_single_chat_agent_executor(ci_params: CodeInterpreterParams) -> Agent def test(): - # sample = "lsコマンドを実行してください。" - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) - ci_params.tools = [] + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) ci_params.tools = CodeInterpreterTools(ci_params).get_all_tools() agent_executors = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) - # agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params=ci_params) - # agent = CodeInterpreterAgent.create_agent_and_executor_experimental(ci_params=ci_params) result = agent_executors[0].invoke({"input": TestPrompt.svg_input_str}) print("result=", result) diff --git a/src/codeinterpreterapi/agents/config/coding_agent/config.yaml b/src/codeinterpreterapi/agents/config/coding_agent/config.yaml index c653757a..2416e9cf 100644 --- a/src/codeinterpreterapi/agents/config/coding_agent/config.yaml +++ b/src/codeinterpreterapi/agents/config/coding_agent/config.yaml @@ -1,5 +1,5 @@ agent_definition: - agent_name: "main関数を作成するエージェント" + agent_name: "main_function_create_agent" agent_type: tool_calling agent_role: | プログラムのmain関数を作成する作業をお願いします。 @@ -8,7 +8,7 @@ agent_definition: あなたの役割は次の通りです。 ・コードの作成 - 
・コードの実行(toolを使って実行してエラーなく実行できるところまで確認してください) + ・コードの実行(toolを使ってコードを実行できます。エラーなく実行できることを確認してください。) ・エラーがある場合の修正と再実行 次に続くシステムプロンプトを注意深く読んで正しくふるまってください。 diff --git a/src/codeinterpreterapi/agents/config/fix_agent/config.yaml b/src/codeinterpreterapi/agents/config/fix_agent/config.yaml index 4d4bda91..2f5dfedf 100644 --- a/src/codeinterpreterapi/agents/config/fix_agent/config.yaml +++ b/src/codeinterpreterapi/agents/config/fix_agent/config.yaml @@ -1,5 +1,5 @@ agent_definition: - agent_name: "プログラムを修正するエージェント" + agent_name: "code_fix_agent" agent_type: structured_chat agent_role: | プログラムを修正する作業をお願いします。 diff --git a/src/codeinterpreterapi/agents/config/split_agent/config.yaml b/src/codeinterpreterapi/agents/config/split_agent/config.yaml index 34b78594..534b5511 100644 --- a/src/codeinterpreterapi/agents/config/split_agent/config.yaml +++ b/src/codeinterpreterapi/agents/config/split_agent/config.yaml @@ -1,5 +1,5 @@ agent_definition: - agent_name: "プログラムを分割するエージェント" + agent_name: "code_split_agent" agent_type: structured_chat agent_role: | プログラムを分割する作業をお願いします。 diff --git a/src/codeinterpreterapi/agents/structured_chat/agent.py b/src/codeinterpreterapi/agents/structured_chat/agent.py index d43a316e..a1876114 100644 --- a/src/codeinterpreterapi/agents/structured_chat/agent.py +++ b/src/codeinterpreterapi/agents/structured_chat/agent.py @@ -176,8 +176,8 @@ def create_structured_chat_agent( def test(): - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) prompt = create_structured_chat_agent_prompt(ci_params.is_ja) agent = create_structured_chat_agent_wrapper(ci_params=ci_params, prompt=prompt) test_input = "pythonで円周率を表示するプログラムを実行してください。" diff --git a/src/codeinterpreterapi/agents/structured_chat/agent_executor.py 
b/src/codeinterpreterapi/agents/structured_chat/agent_executor.py index 88cfdd9a..00ef7c22 100644 --- a/src/codeinterpreterapi/agents/structured_chat/agent_executor.py +++ b/src/codeinterpreterapi/agents/structured_chat/agent_executor.py @@ -39,8 +39,8 @@ def load_structured_chat_agent_executor( def test(): - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) agent_executor = load_structured_chat_agent_executor(ci_params) test_input = "pythonで円周率を表示するプログラムを実行してください。" agent_executor_output = agent_executor.invoke({"input": test_input}) diff --git a/src/codeinterpreterapi/agents/tool_calling/agent.py b/src/codeinterpreterapi/agents/tool_calling/agent.py index ba397e5b..1427938e 100644 --- a/src/codeinterpreterapi/agents/tool_calling/agent.py +++ b/src/codeinterpreterapi/agents/tool_calling/agent.py @@ -111,8 +111,8 @@ def magic_function(input: int) -> int: def test(): - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) prompt = create_tool_calling_agent_prompt(ci_params.is_ja) agent = create_tool_calling_agent_wrapper(ci_params=ci_params, prompt=prompt) test_input = "pythonで円周率を表示するプログラムを実行してください。" diff --git a/src/codeinterpreterapi/agents/tool_calling/agent_executor.py b/src/codeinterpreterapi/agents/tool_calling/agent_executor.py index 516a79a0..0bbd745e 100644 --- a/src/codeinterpreterapi/agents/tool_calling/agent_executor.py +++ b/src/codeinterpreterapi/agents/tool_calling/agent_executor.py @@ -39,8 +39,8 @@ def load_tool_calling_agent_executor( def test(): - llm, llm_tools = prepare_test_llm() - 
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) ci_params.verbose = True ci_params.verbose_prompt = False agent_executor = load_tool_calling_agent_executor(ci_params) diff --git a/src/codeinterpreterapi/brain/brain.py b/src/codeinterpreterapi/brain/brain.py index 89632dde..c507bcf0 100644 --- a/src/codeinterpreterapi/brain/brain.py +++ b/src/codeinterpreterapi/brain/brain.py @@ -175,8 +175,8 @@ def use_agent(self, new_agent: AgentName): def test(): settings.WORK_DIR = "/tmp" - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) ci_params.tools = [] ci_params.tools = CodeInterpreterTools(ci_params).get_all_tools() brain = CodeInterpreterBrain(ci_params) diff --git a/src/codeinterpreterapi/callbacks/callbacks.py b/src/codeinterpreterapi/callbacks/callbacks.py new file mode 100644 index 00000000..8735085c --- /dev/null +++ b/src/codeinterpreterapi/callbacks/callbacks.py @@ -0,0 +1,254 @@ +import sys +from copy import deepcopy +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from langchain.callbacks import StdOutCallbackHandler +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.messages import BaseMessage + + +class CustomStdOutCallbackHandler(StdOutCallbackHandler): + def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we are entering a chain.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201 + 
print(f"inputs={inputs}") + + def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any: + """Print out that we are entering a chain.""" + print(f"on_tool_start input_str={input_str}") + + def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any: + """Run when Chat Model starts running.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + print(f"\n\n\033[1m> Entering new {class_name} on_chat_model_start...\033[0m") # noqa: T201 + + +def get_current_function_name(depth: int = 1) -> str: + return sys._getframe(depth).f_code.co_name + + +def show_callback_info(name: str, tag: str, data: Any) -> None: + current_function_name = get_current_function_name(2) + print("show_callback_info current_function_name=", current_function_name, name) + print(f"{tag}=", trim_data(data)) + + +def trim_data(data: Union[Any, List[Any], Dict[str, Any]]) -> str: + """ + dataの構造をデバッグ表示用に短縮する関数 + + :data: 対象データ + """ + data_copy = deepcopy(data) + return trim_data_iter("", data_copy) + + +def trim_data_iter(indent: str, data: Union[Any, List[Any], Dict[str, Any]]) -> str: + """ + dataの構造をデバッグ表示用に短縮する関数 + + :param data: 対象データ + """ + indent_next = indent + " " + if isinstance(data, dict): + return trim_data_dict(indent_next, data) + elif isinstance(data, list): + return trim_data_array(indent_next, data) + else: + return trim_data_other(indent, data) + + +def trim_data_dict(indent: str, data: Dict[str, Any]) -> str: + """ + dataの構造をデバッグ表示用に短縮する関数 + + :param indent: インデント文字列 + :param data: 対象データ + """ + new_data_list = [] + for k, v in data.items(): + new_data_list.append(f"{indent}dict[{k}]: " + trim_data_iter(indent, v)) + return "\n".join(new_data_list) + + +def trim_data_array(indent: str, data: List[Any]) -> str: + """ + dataの構造をデバッグ表示用に短縮する関数 + + :param indent: インデント文字列 + :param data: 対象データ + """ + new_data_list = [] + for i, item in enumerate(data): + 
print(f"{indent}array[{str(i)}]: ") + new_data_list.append(trim_data_iter(indent, item)) + return "\n".join(new_data_list) + + +def trim_data_other(indent: str, data: Any) -> str: + """ + dataの構造をデバッグ表示用に短縮する関数 + + :param indent: インデント文字列 + :param data: 対象データ + """ + stype = str(type(data)) + s = str(data) + return f"{indent}type={stype}, data={s[:80]}" + + +class FullOutCallbackHandler(CustomStdOutCallbackHandler): + # CallbackManagerMixin + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM starts running. + + **ATTENTION**: This method is called for non-chat models (regular LLMs). If + you're implementing a handler for a chat model, + you should use on_chat_model_start instead. + """ + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + show_callback_info(class_name, "prompts", prompts) + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts running. + + **ATTENTION**: This method is called for chat models. If you're implementing + a handler for a non-chat model, you should use on_llm_start instead. 
+ """ + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + show_callback_info(class_name, "messages", messages) + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when Retriever starts running.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + show_callback_info(class_name, "query", query) + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when chain starts running.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + show_callback_info(class_name, "inputs", inputs) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when tool starts running.""" + class_name = serialized.get("name", serialized.get("id", [""])[-1]) + show_callback_info(class_name, "input_str", input_str) + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain ends running.""" + show_callback_info("no_name", "outputs", outputs) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain errors.""" + show_callback_info("no_name", "error", error) + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, 
+ **kwargs: Any, + ) -> Any: + """Run on agent action.""" + show_callback_info("no_name", "action", action) + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent end.""" + show_callback_info("no_name", "finish", finish) + + # ToolManagerMixin + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool ends running.""" + show_callback_info("no_name", "output", output) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool errors.""" + show_callback_info("no_name", "error", error) diff --git a/src/codeinterpreterapi/chains/modifications_check.py b/src/codeinterpreterapi/chains/modifications_check.py index 6fd68056..9d615651 100644 --- a/src/codeinterpreterapi/chains/modifications_check.py +++ b/src/codeinterpreterapi/chains/modifications_check.py @@ -47,7 +47,7 @@ async def aget_file_modifications( async def test() -> None: - llm, llm_tools = prepare_test_llm() + llm, llm_tools, runnable_config = prepare_test_llm() code = """ import matplotlib.pyplot as plt diff --git a/src/codeinterpreterapi/graphs/agent_wrapper_tool.py b/src/codeinterpreterapi/graphs/agent_wrapper_tool.py new file mode 100644 index 00000000..9cef4381 --- /dev/null +++ b/src/codeinterpreterapi/graphs/agent_wrapper_tool.py @@ -0,0 +1,62 @@ +from typing import Any, Optional, Type + +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import BaseTool + +from codeinterpreterapi.brain.params import CodeInterpreterParams + + +class CustomFunctionInput(BaseModel): + question: str = Field(description="the original question to response user.") + message: str = Field(description="response message from this tool.") + + 
+class AgentWrapperTool(BaseTool): + """Tool that wraps an agent and exposes it as a LangChain tool.""" + + agent_impl: Optional[Any] = None + args_schema: Type[BaseModel] = CustomFunctionInput + + @classmethod + def create_agent_wrapper_tools(cls, ci_params: CodeInterpreterParams) -> None: + # 各エージェントに対してノードを作成 + tools = [] + for agent_def in ci_params.agent_def_list: + agent = agent_def.agent_executor + agent_name = agent_def.agent_name + agent_role = agent_def.agent_role + tool = cls(agent_name=agent_name, agent_role=agent_role, agent_impl=agent) + tools.append(tool) + + return tools + + def __init__(self, agent_name: str, agent_role: str, agent_impl: Any): + super().__init__( + name=agent_name, + description=f"A tool that wraps the {agent_name} agent with role: {agent_role}. " + f"Input should be a query or task for the {agent_name} agent.", + ) + self.agent_impl = agent_impl + + def _run( + self, + question: str, + message: str, + ) -> str: + """Use the tool.""" + messages = [] + # messages.append(HumanMessage(question)) + messages.append(AIMessage(message)) + return self.agent_impl.invoke({"input": question, "messages": messages}) + + async def _arun( + self, + question: str, + message: str, + ) -> str: + """Use the tool.""" + messages = [] + messages.append(HumanMessage(question)) + messages.append(AIMessage(message)) + return await self.agent_impl.ainvoke(messages) diff --git a/src/codeinterpreterapi/graphs/graphs.py b/src/codeinterpreterapi/graphs/graphs.py index cb612127..53a4b47d 100644 --- a/src/codeinterpreterapi/graphs/graphs.py +++ b/src/codeinterpreterapi/graphs/graphs.py @@ -2,11 +2,13 @@ import operator from typing import Annotated, Sequence, TypedDict -from langchain_core.messages import BaseMessage +from langchain_core.messages import AnyMessage, BaseMessage from langgraph.graph import END, StateGraph from codeinterpreterapi.agents.agents import CodeInterpreterAgent from codeinterpreterapi.brain.params import CodeInterpreterParams +from 
codeinterpreterapi.callbacks.callbacks import show_callback_info +from codeinterpreterapi.graphs.tool_node.tool_node import create_agent_nodes from codeinterpreterapi.llm.llm import prepare_test_llm from codeinterpreterapi.planners.planners import CodeInterpreterPlanner from codeinterpreterapi.supervisors.supervisors import CodeInterpreterSupervisor @@ -63,36 +65,54 @@ def supervisor_node(state, supervisor, name): return state +# グラフで使用する変数(状態)を定義 +class GraphState(TypedDict): + # llm_bind_tool: BaseLLM # ツールが紐付けされたllmモデル + # emb_model: HuggingFaceEmbeddings # Embeddingsモデル + question: str # 質問文 + # documents: List[Document] # indexから取得したドキュメントのリスト + messages: Annotated[Sequence[BaseMessage], operator.add] = [] + # intermediate_steps: str = "" + + +def should_end(state: GraphState): + last_message = state["messages"][-1] + return "FINAL ANSWER:" in last_message.content + + class CodeInterpreterStateGraph: def __init__(self, ci_params: CodeInterpreterParams): self.ci_params = ci_params self.node_descriptions_dict = {} self.node_agent_dict = {} - self.initialize_agent_info() + # self.initialize_agent_info() self.graph = self.initialize_graph() - def initialize_agent_info(self) -> None: - # 各エージェントに対してノードを作成 - for agent_def in self.ci_params.agent_def_list: - agent = agent_def.agent_executor - agent_name = agent_def.agent_name - agent_role = agent_def.agent_role - - self.node_descriptions_dict[agent_name] = agent_role - self.node_agent_dict[agent_name] = agent - - # グラフで使用する変数(状態)を定義 - class GraphState(TypedDict): - # llm_bind_tool: BaseLLM # ツールが紐付けされたllmモデル - # emb_model: HuggingFaceEmbeddings # Embeddingsモデル - question: str # 質問文 - # documents: List[Document] # indexから取得したドキュメントのリスト - messages: Annotated[Sequence[BaseMessage], operator.add] = [] - # intermediate_steps: str = "" + # メッセージ変更関数の準備 + def _modify_messages(self, messages: list[AnyMessage]): + show_callback_info("_modify_messages=", "messages", messages) + last_message = messages[0] + return 
[last_message] def initialize_graph(self) -> StateGraph: - workflow = StateGraph(CodeInterpreterStateGraph.GraphState) - + workflow = StateGraph(GraphState) + + agent_nodes = create_agent_nodes(self.ci_params) + is_first = True + for i, agent_node in enumerate(agent_nodes): + agent_name = f"agent{i}" + workflow.add_node(agent_name, agent_node) + # エージェントの実行後、即座に終了 + workflow.add_edge(agent_name, END) + if is_first: + workflow.set_entry_point(agent_name) + is_first = False + break + + return workflow.compile() + + def initialize_graph2(self) -> StateGraph: + workflow = StateGraph(GraphState) SUPERVISOR_AGENT_NAME = "supervisor_agent" supervisor_node_replaced = functools.partial( supervisor_node, supervisor=self.ci_params.supervisor_agent, name=SUPERVISOR_AGENT_NAME @@ -123,8 +143,8 @@ def run(self, input_data): def test(): - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) _ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params) _ = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params) diff --git a/src/codeinterpreterapi/graphs/tool_node/tool_node.py b/src/codeinterpreterapi/graphs/tool_node/tool_node.py new file mode 100644 index 00000000..e567611a --- /dev/null +++ b/src/codeinterpreterapi/graphs/tool_node/tool_node.py @@ -0,0 +1,91 @@ +from typing import Sequence + +from langchain.agents import AgentExecutor +from langchain.chat_models.base import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.tools import BaseTool, tool +from langgraph.graph import MessageGraph, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + +from codeinterpreterapi.agents.agents import 
CodeInterpreterAgent +from codeinterpreterapi.brain.params import CodeInterpreterParams +from codeinterpreterapi.graphs.agent_wrapper_tool import AgentWrapperTool +from codeinterpreterapi.llm.llm import prepare_test_llm +from codeinterpreterapi.planners.planners import CodeInterpreterPlanner +from codeinterpreterapi.supervisors.supervisors import CodeInterpreterSupervisor +from codeinterpreterapi.test_prompts.test_prompt import TestPrompt + + +@tool +def divide(a: float, b: float) -> int: + """Return a / b.""" + return a / b + + +def create_agent_nodes(ci_params: CodeInterpreterParams): + agent_nodes = [] + for agent_def in ci_params.agent_def_list: + agent_executor = agent_def.agent_executor + agent_node = create_agent_node(agent_executor) + agent_nodes.append(agent_node) + return agent_nodes + + +def create_agent_node(agent_executor: AgentExecutor): + def agent_function(state, context): + """ + AgentExecutorを実行し、結果を状態に追加する関数。 + + Args: + state (dict): 現在の状態。messagesキーを含む必要がある。 + context (dict): 実行コンテキスト。 + + Returns: + dict: 更新された状態。 + """ + # stateから必要な情報を取得 + messages = state["messages"] + + # AgentExecutorを実行 + result = agent_executor.invoke({"input": messages[-1].content}) + + # 結果を新しいメッセージとして追加 + new_message = AIMessage(content=result["output"]) + state["messages"] = messages + [new_message] + print(state) + return state + + return ToolNode([agent_function]) + + +def create_tool_node( + state_graph: StateGraph, + llm: BaseChatModel, + tools: Sequence[BaseTool], +) -> StateGraph: + state_graph.add_node("tools", ToolNode(tools)) + state_graph.add_node("chatbot", llm.bind_tools(tools)) + state_graph.add_edge("tools", "chatbot") + state_graph.add_conditional_edges("chatbot", tools_condition) + return state_graph + + +def test(): + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) + _ = 
CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) + planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params) + _ = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params) + + state_graph = MessageGraph() + tools = AgentWrapperTool.create_agent_wrapper_tools(ci_params) + # state_graph = create_tool_node(state_graph, llm_tools, tools) + state_graph = create_tool_node(state_graph, llm_tools, tools) + state_graph.set_entry_point("chatbot") + compiled_graph = state_graph.compile() + result = compiled_graph.invoke([("user", TestPrompt.svg_input_str)]) + print(result[-1].content) + + +if __name__ == "__main__": + test() diff --git a/src/codeinterpreterapi/llm/llm.py b/src/codeinterpreterapi/llm/llm.py index 5ae8a33b..df59a94b 100644 --- a/src/codeinterpreterapi/llm/llm.py +++ b/src/codeinterpreterapi/llm/llm.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from dotenv import load_dotenv from google.ai.generativelanguage_v1beta.types import GenerateContentRequest from google.generativeai.types.content_types import FunctionDeclarationType # type: ignore[import] -from langchain.callbacks import StdOutCallbackHandler from langchain.chat_models.base import BaseChatModel from langchain_core.messages import BaseMessage from langchain_core.runnables import Runnable, RunnableConfig @@ -10,6 +10,7 @@ from langchain_google_genai._common import SafetySettingDict from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDictLike +from codeinterpreterapi.callbacks.callbacks import FullOutCallbackHandler from codeinterpreterapi.config import settings @@ -149,10 +150,14 @@ def get_llms(cls, model: str = settings.MODEL_LOCAL) -> List[BaseChatModel]: def prepare_test_llm(): + load_dotenv(verbose=True, override=False) + llm = CodeInterpreterLlm.get_llm_switcher() llm_tools = CodeInterpreterLlm.get_llm_switcher_tools() - runnable_config = RunnableConfig({'callbacks': 
[StdOutCallbackHandler()]}) - llm = llm.with_config(runnable_config) - llm_tools = llm_tools.with_config(runnable_config) + callbacks = [FullOutCallbackHandler()] + configurable = {"session_id": "123"} + runnable_config = RunnableConfig(callbacks=callbacks, configurable=configurable) + llm = llm.with_config(config=runnable_config) + llm_tools = llm_tools.with_config(config=runnable_config) - return llm, llm_tools + return (llm, llm_tools, runnable_config) diff --git a/src/codeinterpreterapi/planners/planners.py b/src/codeinterpreterapi/planners/planners.py index ac28f23e..eeb725db 100644 --- a/src/codeinterpreterapi/planners/planners.py +++ b/src/codeinterpreterapi/planners/planners.py @@ -96,8 +96,8 @@ def get_prompt(): def test(): sample = "ステップバイステップで2*5+2を計算して。" - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params) result = planner.invoke({"input": sample, "agent_scratchpad": ""}) print("result=", result) diff --git a/src/codeinterpreterapi/supervisors/supervisors.py b/src/codeinterpreterapi/supervisors/supervisors.py index 829916ec..56734250 100644 --- a/src/codeinterpreterapi/supervisors/supervisors.py +++ b/src/codeinterpreterapi/supervisors/supervisors.py @@ -144,6 +144,12 @@ def route(next_action: str) -> str: ci_params.supervisor_agent = supervisor_agent_structured_output + # config + if ci_params.runnable_config: + supervisor_agent_structured_output = supervisor_agent_structured_output.with_config( + ci_params.runnable_config + ) + # agent_executor # agent_executor = AgentExecutor.from_agent_and_tools( # agent=supervisor_agent_for_executor, @@ -156,8 +162,8 @@ def route(next_action: str) -> str: def test(): - llm, llm_tools = prepare_test_llm() - 
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) _ = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params) planner = CodeInterpreterPlanner.choose_planner(ci_params=ci_params) supervisor = CodeInterpreterSupervisor.choose_supervisor(planner=planner, ci_params=ci_params) diff --git a/src/codeinterpreterapi/thoughts/checker.py b/src/codeinterpreterapi/thoughts/checker.py index d5873cc3..ff9c0a74 100644 --- a/src/codeinterpreterapi/thoughts/checker.py +++ b/src/codeinterpreterapi/thoughts/checker.py @@ -241,7 +241,7 @@ def get_thought_validity(self, thought_validity) -> ThoughtValidity: # Testing the MyChecker class above: ####### def test_checker(): - llm, llm_tools = prepare_test_llm() + llm, llm_tools, runnable_config = prepare_test_llm() tot_chain = create_tot_chain_from_llm(llm) checker = tot_chain.checker assert ( diff --git a/src/codeinterpreterapi/thoughts/thoughts.py b/src/codeinterpreterapi/thoughts/thoughts.py index 96c12e1f..a281ac1c 100644 --- a/src/codeinterpreterapi/thoughts/thoughts.py +++ b/src/codeinterpreterapi/thoughts/thoughts.py @@ -52,8 +52,8 @@ def get_runnable_tot_chain(cls, ci_params: CodeInterpreterParams, is_simple: boo def test(): - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) tot_chain = CodeInterpreterToT.get_runnable_tot_chain(ci_params=ci_params) tot_chain.invoke({"input": sample2}) diff --git a/src/codeinterpreterapi/tools/code_checker.py b/src/codeinterpreterapi/tools/code_checker.py index 82d9f080..63137a6c 100644 --- a/src/codeinterpreterapi/tools/code_checker.py +++ 
b/src/codeinterpreterapi/tools/code_checker.py @@ -40,8 +40,8 @@ async def _aget_latest_code(self): def test(): settings.WORK_DIR = "/tmp" - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) tools_instance = CodeChecker(ci_params=ci_params) result = tools_instance._get_latest_code() print("result=", result) diff --git a/src/codeinterpreterapi/tools/python.py b/src/codeinterpreterapi/tools/python.py index 133cd981..986929ca 100644 --- a/src/codeinterpreterapi/tools/python.py +++ b/src/codeinterpreterapi/tools/python.py @@ -175,8 +175,8 @@ async def ashow_code(self, code: str) -> None: def test(): settings.WORK_DIR = "/tmp" - llm, llm_tools = prepare_test_llm() - ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools) + llm, llm_tools, runnable_config = prepare_test_llm() + ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config) tools_instance = PythonTools(ci_params=ci_params) test_code = "print('test output')" result = tools_instance._run_handler(test_code)