Skip to content

Commit

Permalink
fix: add langgraph try in working
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Jul 15, 2024
1 parent cbec7ab commit 26af5e0
Show file tree
Hide file tree
Showing 21 changed files with 496 additions and 63 deletions.
11 changes: 3 additions & 8 deletions src/codeinterpreterapi/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,18 @@ def create_single_chat_agent_executor(ci_params: CodeInterpreterParams) -> Agent
# return_messages=True,
# chat_memory=chat_memory,
# ),
callbacks=ci_params.callbacks,
# callbacks=ci_params.callbacks,
)
print("agent_executor.input_keys", agent_executor.input_keys)
print("agent_executor.output_keys", agent_executor.output_keys)
return agent_executor


def test():
# sample = "lsコマンドを実行してください。"
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
ci_params.tools = []
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
ci_params.tools = CodeInterpreterTools(ci_params).get_all_tools()

agent_executors = CodeInterpreterAgent.choose_agent_executors(ci_params=ci_params)
# agent = CodeInterpreterAgent.choose_single_chat_agent(ci_params=ci_params)
# agent = CodeInterpreterAgent.create_agent_and_executor_experimental(ci_params=ci_params)
result = agent_executors[0].invoke({"input": TestPrompt.svg_input_str})
print("result=", result)

Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/config/coding_agent/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
agent_definition:
agent_name: "main関数を作成するエージェント"
agent_name: "main_function_create_agent"
agent_type: tool_calling
agent_role: |
プログラムのmain関数を作成する作業をお願いします。
Expand All @@ -8,7 +8,7 @@ agent_definition:
あなたの役割は次の通りです。
・コードの作成
・コードの実行(toolを使って実行してエラーなく実行できるところまで確認してください)
・コードの実行(toolを使ってコードを実行できます。エラーなく実行できることを確認してください。)
・エラーがある場合の修正と再実行
次に続くシステムプロンプトを注意深く読んで正しくふるまってください。
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/agents/config/fix_agent/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
agent_definition:
agent_name: "プログラムを修正するエージェント"
agent_name: "code_fix_agent"
agent_type: structured_chat
agent_role: |
プログラムを修正する作業をお願いします。
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
agent_definition:
agent_name: "プログラムを分割するエージェント"
agent_name: "code_split_agent"
agent_type: structured_chat
agent_role: |
プログラムを分割する作業をお願いします。
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/structured_chat/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def create_structured_chat_agent(


def test():
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
prompt = create_structured_chat_agent_prompt(ci_params.is_ja)
agent = create_structured_chat_agent_wrapper(ci_params=ci_params, prompt=prompt)
test_input = "pythonで円周率を表示するプログラムを実行してください。"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def load_structured_chat_agent_executor(


def test():
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
agent_executor = load_structured_chat_agent_executor(ci_params)
test_input = "pythonで円周率を表示するプログラムを実行してください。"
agent_executor_output = agent_executor.invoke({"input": test_input})
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/tool_calling/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def magic_function(input: int) -> int:


def test():
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
prompt = create_tool_calling_agent_prompt(ci_params.is_ja)
agent = create_tool_calling_agent_wrapper(ci_params=ci_params, prompt=prompt)
test_input = "pythonで円周率を表示するプログラムを実行してください。"
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/agents/tool_calling/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def load_tool_calling_agent_executor(


def test():
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
ci_params.verbose = True
ci_params.verbose_prompt = False
agent_executor = load_tool_calling_agent_executor(ci_params)
Expand Down
4 changes: 2 additions & 2 deletions src/codeinterpreterapi/brain/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def use_agent(self, new_agent: AgentName):

def test():
settings.WORK_DIR = "/tmp"
llm, llm_tools = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools)
llm, llm_tools, runnable_config = prepare_test_llm()
ci_params = CodeInterpreterParams.get_test_params(llm=llm, llm_tools=llm_tools, runnable_config=runnable_config)
ci_params.tools = []
ci_params.tools = CodeInterpreterTools(ci_params).get_all_tools()
brain = CodeInterpreterBrain(ci_params)
Expand Down
254 changes: 254 additions & 0 deletions src/codeinterpreterapi/callbacks/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import sys
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from langchain.callbacks import StdOutCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage


class CustomStdOutCallbackHandler(StdOutCallbackHandler):
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we are entering a chain."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
print(f"inputs={inputs}")

def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Print out that we are entering a chain."""
print(f"on_tool_start input_str={input_str}")

def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
"""Run when Chat Model starts running."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} on_chat_model_start...\033[0m") # noqa: T201


def get_current_function_name(depth: int = 1) -> str:
return sys._getframe(depth).f_code.co_name


def show_callback_info(name: str, tag: str, data: Any) -> None:
current_function_name = get_current_function_name(2)
print("show_callback_info current_function_name=", current_function_name, name)
print(f"{tag}=", trim_data(data))


def trim_data(data: Union[Any, List[Any], Dict[str, Any]]) -> str:
"""
dataの構造をデバッグ表示用に短縮する関数
:data: 対象データ
"""
data_copy = deepcopy(data)
return trim_data_iter("", data_copy)


def trim_data_iter(indent: str, data: Union[Any, List[Any], Dict[str, Any]]) -> str:
"""
dataの構造をデバッグ表示用に短縮する関数
:param data: 対象データ
"""
indent_next = indent + " "
if isinstance(data, dict):
return trim_data_dict(indent_next, data)
elif isinstance(data, list):
return trim_data_array(indent_next, data)
else:
return trim_data_other(indent, data)


def trim_data_dict(indent: str, data: Dict[str, Any]) -> str:
"""
dataの構造をデバッグ表示用に短縮する関数
:param indent: インデント文字列
:param data: 対象データ
"""
new_data_list = []
for k, v in data.items():
new_data_list.append(f"{indent}dict[{k}]: " + trim_data_iter(indent, v))
return "\n".join(new_data_list)


def trim_data_array(indent: str, data: List[Any]) -> str:
"""
dataの構造をデバッグ表示用に短縮する関数
:param indent: インデント文字列
:param data: 対象データ
"""
new_data_list = []
for i, item in enumerate(data):
print(f"{indent}array[{str(i)}]: ")
new_data_list.append(trim_data_iter(indent, item))
return "\n".join(new_data_list)


def trim_data_other(indent: str, data: Any) -> str:
"""
dataの構造をデバッグ表示用に短縮する関数
:param indent: インデント文字列
:param data: 対象データ
"""
stype = str(type(data))
s = str(data)
return f"{indent}type={stype}, data={s[:80]}"


class FullOutCallbackHandler(CustomStdOutCallbackHandler):
# CallbackManagerMixin
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running.
**ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model,
you should use on_chat_model_start instead.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
show_callback_info(class_name, "prompts", prompts)

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running.
**ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead.
"""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
show_callback_info(class_name, "messages", messages)

def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
show_callback_info(class_name, "query", query)

def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
show_callback_info(class_name, "inputs", inputs)

def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
show_callback_info(class_name, "input_str", input_str)

def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
show_callback_info("no_name", "outputs", outputs)

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
show_callback_info("no_name", "error", error)

def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
show_callback_info("no_name", "action", action)

def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
show_callback_info("no_name", "finish", finish)

# ToolManagerMixin
def on_tool_end(
self,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
show_callback_info("no_name", "output", output)

def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
show_callback_info("no_name", "error", error)
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/chains/modifications_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def aget_file_modifications(


async def test() -> None:
llm, llm_tools = prepare_test_llm()
llm, llm_tools, runnable_config = prepare_test_llm()

code = """
import matplotlib.pyplot as plt
Expand Down
Loading

0 comments on commit 26af5e0

Please sign in to comment.