fix: add MarkdownFileCallbackHandler
nobu007 committed Aug 10, 2024
1 parent 6fabd38 commit 5e1bda2
Showing 8 changed files with 164 additions and 88 deletions.
63 changes: 63 additions & 0 deletions src/codeinterpreterapi/callbacks/markdown/callbacks.py
@@ -0,0 +1,63 @@
import datetime
import os
from typing import Any, Dict, List

from langchain.callbacks import FileCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult


class MarkdownFileCallbackHandler(FileCallbackHandler):
    def __init__(self, filename: str = "langchain_log.md"):
        if os.path.isfile(filename):
            os.remove(filename)
        super().__init__(filename, "a")
        self.step_count = 0

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        self.step_count += 1
        self._write_to_file(f"## Step {self.step_count}: LLM Start\n\n")
        self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
        self._write_to_file("**Prompts:**\n\n")
        for i, prompt in enumerate(prompts, 1):
            self._write_to_file(f"```\nPrompt {i}:\n{prompt}\n```\n\n")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._write_to_file("**LLM Response:**\n\n")
        for generation in response.generations[0]:
            self._write_to_file(f"```\n{generation.text}\n```\n\n")
        self._write_to_file("---\n\n")

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        self.step_count += 1
        chain_name = serialized.get("name", "Unknown Chain")
        self._write_to_file(f"## Step {self.step_count}: Chain Start - {chain_name}\n\n")
        self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
        self._write_to_file("**Inputs:**\n\n")
        self._write_to_file(f"```\n{inputs}\n```\n\n")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        self._write_to_file("**Outputs:**\n\n")
        self._write_to_file(f"```\n{outputs}\n```\n\n")
        self._write_to_file("---\n\n")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        self.step_count += 1
        self._write_to_file(f"## Step {self.step_count}: Agent Action\n\n")
        self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
        self._write_to_file(f"**Tool:** {action.tool}\n\n")
        self._write_to_file("**Tool Input:**\n\n")
        self._write_to_file(f"```\n{action.tool_input}\n```\n\n")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        self._write_to_file("## Agent Finish\n\n")
        self._write_to_file(f"**Timestamp:** {self._get_timestamp()}\n\n")
        self._write_to_file("**Output:**\n\n")
        self._write_to_file(f"```\n{finish.return_values}\n```\n\n")
        self._write_to_file("---\n\n")

    def _write_to_file(self, text: str) -> None:
        self.file.write(text)
        self.file.flush()

    def _get_timestamp(self) -> str:
        return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
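
For reference, a minimal usage sketch (not part of the commit): the handler can be attached through a RunnableConfig so that LLM start/end events are appended to the Markdown log. The stand-in model and file name below are illustrative, and the sketch assumes the installed langchain_core version ships FakeListLLM.

# Minimal usage sketch (illustrative, not part of this commit).
# Assumes langchain_core provides FakeListLLM; the log file name is arbitrary.
from langchain_core.language_models.fake import FakeListLLM
from langchain_core.runnables import RunnableConfig

from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler

handler = MarkdownFileCallbackHandler("markdown_log_demo.md")
llm = FakeListLLM(responses=["print('hello')"])

# Callbacks passed via the config receive on_llm_start / on_llm_end,
# so each invocation is written as a "## Step N: LLM Start" section.
result = llm.invoke("Write a hello-world script.", config=RunnableConfig(callbacks=[handler]))
print(result)
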
src/codeinterpreterapi/callbacks/stdout/callbacks.py
@@ -1,12 +1,12 @@
import sys
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import StdOutCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage

from codeinterpreterapi.callbacks.util import show_callback_info


class CustomStdOutCallbackHandler(StdOutCallbackHandler):
    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
@@ -25,80 +25,6 @@ def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[Ba
        print(f"\n\n\033[1m> Entering new {class_name} on_chat_model_start...\033[0m") # noqa: T201


def get_current_function_name(depth: int = 1) -> str:
    return sys._getframe(depth).f_code.co_name


def show_callback_info(name: str, tag: str, data: Any) -> None:
    current_function_name = get_current_function_name(2)
    print("show_callback_info current_function_name=", current_function_name, name)
    print(f"{tag}=", trim_data(data))


def trim_data(data: Union[Any, List[Any], Dict[str, Any]]) -> str:
    """
    Shorten the structure of data for debug display.
    :param data: target data
    """
    data_copy = deepcopy(data)
    return trim_data_iter("", data_copy)


def trim_data_iter(indent: str, data: Union[Any, List[Any], Dict[str, Any]]) -> str:
    """
    Shorten the structure of data for debug display.
    :param data: target data
    """
    indent_next = indent + " "
    if isinstance(data, dict):
        return trim_data_dict(indent_next, data)
    elif isinstance(data, list):
        return trim_data_array(indent_next, data)
    else:
        return trim_data_other(indent, data)


def trim_data_dict(indent: str, data: Dict[str, Any]) -> str:
    """
    Shorten the structure of data for debug display.
    :param indent: indentation string
    :param data: target data
    """
    new_data_list = []
    for k, v in data.items():
        new_data_list.append(f"{indent}dict[{k}]: " + trim_data_iter(indent, v))
    return "\n".join(new_data_list)


def trim_data_array(indent: str, data: List[Any]) -> str:
    """
    Shorten the structure of data for debug display.
    :param indent: indentation string
    :param data: target data
    """
    new_data_list = []
    for i, item in enumerate(data):
        print(f"{indent}array[{str(i)}]: ")
        new_data_list.append(trim_data_iter(indent, item))
    return "\n".join(new_data_list)


def trim_data_other(indent: str, data: Any) -> str:
    """
    Shorten the structure of data for debug display.
    :param indent: indentation string
    :param data: target data
    """
    stype = str(type(data))
    s = str(data)
    return f"{indent}type={stype}, data={s[:80]}"


class FullOutCallbackHandler(CustomStdOutCallbackHandler):
    # CallbackManagerMixin
    def on_llm_start(
77 changes: 77 additions & 0 deletions src/codeinterpreterapi/callbacks/util.py
@@ -0,0 +1,77 @@
import sys
from copy import deepcopy
from typing import Any, Dict, List, Union


def get_current_function_name(depth: int = 1) -> str:
    return sys._getframe(depth).f_code.co_name


def show_callback_info(name: str, tag: str, data: Any) -> None:
    current_function_name = get_current_function_name(2)
    print("show_callback_info current_function_name=", current_function_name, name)
    print(f"{tag}=", trim_data(data))


def trim_data(data: Union[Any, List[Any], Dict[str, Any]]) -> str:
    """
    Shorten the structure of data for debug display.
    :param data: target data
    """
    data_copy = deepcopy(data)
    return trim_data_iter("", data_copy)


def trim_data_iter(indent: str, data: Union[Any, List[Any], Dict[str, Any]]) -> str:
    """
    Shorten the structure of data for debug display.
    :param data: target data
    """
    indent_next = indent + " "
    if isinstance(data, dict):
        return trim_data_dict(indent_next, data)
    elif isinstance(data, list):
        return trim_data_array(indent_next, data)
    else:
        return trim_data_other(indent, data)


def trim_data_dict(indent: str, data: Dict[str, Any]) -> str:
    """
    Shorten the structure of data for debug display.
    :param indent: indentation string
    :param data: target data
    """
    new_data_list = []
    for k, v in data.items():
        new_data_list.append(f"{indent}dict[{k}]: " + trim_data_iter(indent, v))
    return "\n".join(new_data_list)


def trim_data_array(indent: str, data: List[Any]) -> str:
    """
    Shorten the structure of data for debug display.
    :param indent: indentation string
    :param data: target data
    """
    new_data_list = []
    for i, item in enumerate(data):
        print(f"{indent}array[{str(i)}]: ")
        new_data_list.append(trim_data_iter(indent, item))
    return "\n".join(new_data_list)


def trim_data_other(indent: str, data: Any) -> str:
    """
    Shorten the structure of data for debug display.
    :param indent: indentation string
    :param data: target data
    """
    stype = str(type(data))
    s = str(data)
    return f"{indent}type={stype}, data={s[:80]}"
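
As a quick illustration (not part of the commit), here is how the new helper shortens a nested structure for debug output; the sample data below is made up.

# Illustrative call of the debug helper added in this file (sample data is made up).
from codeinterpreterapi.callbacks.util import show_callback_info

sample = {
    "input": "plot a sine wave",
    "history": ["previous question " * 20],  # long leaf values are cut to 80 characters
}
# Prints the caller's frame name and the tag, then the structure with each leaf
# rendered as "type=<class ...>, data=<first 80 chars>".
show_callback_info("demo", "inputs", sample)
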
2 changes: 1 addition & 1 deletion src/codeinterpreterapi/graphs/graphs.py
@@ -7,7 +7,7 @@

from codeinterpreterapi.agents.agents import CodeInterpreterAgent
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.callbacks.callbacks import show_callback_info
from codeinterpreterapi.callbacks.util import show_callback_info
from codeinterpreterapi.graphs.tool_node.tool_node import create_agent_nodes
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.planners.planners import CodeInterpreterPlanner
5 changes: 3 additions & 2 deletions src/codeinterpreterapi/llm/llm.py
@@ -12,7 +12,8 @@
from langchain_google_genai._common import SafetySettingDict
from langchain_google_genai._function_utils import _ToolConfigDict

from codeinterpreterapi.callbacks.callbacks import FullOutCallbackHandler
from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler
from codeinterpreterapi.callbacks.stdout.callbacks import FullOutCallbackHandler
from codeinterpreterapi.config import settings


@@ -159,7 +160,7 @@ def prepare_test_llm(is_smart: bool = False):
    else:
        llm = CodeInterpreterLlm.get_llm_switcher()
        llm_tools = CodeInterpreterLlm.get_llm_switcher_tools()
    callbacks = [FullOutCallbackHandler()]
    callbacks = [FullOutCallbackHandler(), MarkdownFileCallbackHandler("langchain_log_test.md")]
    configurable = {"session_id": "123"}
    runnable_config = RunnableConfig(callbacks=callbacks, configurable=configurable)
    llm = llm.with_config(config=runnable_config)
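
To sketch what the updated prepare_test_llm wiring does (illustrative only): both handlers ride on one RunnableConfig that is bound to the model with with_config. A FakeListLLM stands in for the real LLM switcher, which is an assumption of this sketch, not part of the commit.

# Illustrative sketch of the prepare_test_llm wiring with a stand-in model.
# FakeListLLM replaces the real LLM switcher purely for demonstration.
from langchain_core.language_models.fake import FakeListLLM
from langchain_core.runnables import RunnableConfig

from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler
from codeinterpreterapi.callbacks.stdout.callbacks import FullOutCallbackHandler

callbacks = [FullOutCallbackHandler(), MarkdownFileCallbackHandler("langchain_log_test.md")]
runnable_config = RunnableConfig(callbacks=callbacks, configurable={"session_id": "123"})

# with_config binds the config to the runnable, so every later invoke/stream call
# reports to both handlers without passing the config again.
llm = FakeListLLM(responses=["ok"]).with_config(config=runnable_config)
llm.invoke("ping")
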
8 changes: 7 additions & 1 deletion src/codeinterpreterapi/planners/planners.py
@@ -88,14 +88,20 @@ def choose_planner(ci_params: CodeInterpreterParams) -> Union[Runnable, AgentExe
        runnable = prompt | structured_llm
        ci_params.planner_agent = runnable

        # config
        if ci_params.runnable_config:
            runnable = runnable.with_config(ci_params.runnable_config)

        # agent
        agent = RunnableAgent(runnable=runnable, input_keys=list(prompt.input_variables))

        # return executor or runnable
        return_as_executor = False
        if return_as_executor:
            # TODO: handle step by step by original OutputParser
            agent_executor = AgentExecutor(agent=agent, tools=ci_params.tools, verbose=ci_params.verbose)
            agent_executor = AgentExecutor(
                agent=agent, tools=ci_params.tools, verbose=ci_params.verbose, callbacks=ci_params.callbacks
            )
            return agent_executor

        else:
4 changes: 3 additions & 1 deletion src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@

from codeinterpreterapi.brain.brain import CodeInterpreterBrain
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.callbacks.markdown.callbacks import MarkdownFileCallbackHandler
from codeinterpreterapi.chains import aremove_download_link
from codeinterpreterapi.chat_history import CodeBoxChatMessageHistory
from codeinterpreterapi.config import settings
@@ -101,7 +102,8 @@ def __init__(
        init_session_id = "12345678-1234-1234-1234-123456789abc"
        configurable = {"session_id": init_session_id} # TODO: set session_id
        runnable_config = RunnableConfig(
            configurable=configurable, callbacks=[AgentCallbackHandler(self._output_handler)]
            configurable=configurable,
            callbacks=[AgentCallbackHandler(self._output_handler), MarkdownFileCallbackHandler("langchain_log.md")],
        )

        # ci_params = {}
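
A small hypothetical sketch of how a config built this way travels with the run: the configurable session_id is visible to any step that asks for the config, and the callbacks ride along with it. The echo_session function below is made up for illustration and is not part of the codebase.

# Hypothetical sketch: one RunnableConfig carries both session_id and callbacks.
from langchain_core.runnables import RunnableConfig, RunnableLambda

def echo_session(inputs: dict, config: RunnableConfig) -> dict:
    # RunnableLambda passes the active config to functions that declare a
    # `config` parameter, so the session_id set at invoke time is readable here.
    return {"session_id": config["configurable"]["session_id"], **inputs}

cfg = RunnableConfig(
    configurable={"session_id": "12345678-1234-1234-1234-123456789abc"},
    callbacks=[],  # e.g. AgentCallbackHandler(...) and MarkdownFileCallbackHandler(...) as in session.py
)
print(RunnableLambda(echo_session).invoke({"input": "hi"}, config=cfg))
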
13 changes: 7 additions & 6 deletions src/codeinterpreterapi/supervisors/supervisors.py
Original file line number Diff line number Diff line change
@@ -55,15 +55,16 @@ class RouteSchema(BaseModel):
        # agent
        # TODO: use RouteSchema to determine use crew or agent or agent executor
        # llm_with_structured_output = self.ci_params.llm.with_structured_output(RouteSchema)
        supervisor_agent = prompt | self.ci_params.llm
        runnable = prompt | self.ci_params.llm

        # config
        if self.ci_params.runnable_config:
            runnable = runnable.with_config(self.ci_params.runnable_config)

        # supervisor_agent_for_executor = prompt | ci_params.llm
        # self.supervisor_chain = self.planner | prompt | llm_with_structured_output
        self.supervisor_chain = self.planner
        self.ci_params.supervisor_agent = supervisor_agent

        # config
        # if self.ci_params.runnable_config:
        #     self.supervisor_chain = self.supervisor_chain.with_config(self.ci_params.runnable_config)
        self.ci_params.supervisor_agent = runnable

        # supervisor_chain_no_agent
        self.supervisor_chain_no_agent = self.ci_params.llm
