Skip to content

Commit

Permalink
fix: src/codeinterpreterapi/session.py for CallbackHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Jan 26, 2025
1 parent cf28909 commit e6ed971
Showing 1 changed file with 175 additions and 25 deletions.
200 changes: 175 additions & 25 deletions src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import re
import traceback
from types import TracebackType
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Type, Union
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Type
from uuid import UUID

from codeboxapi import CodeBox # type: ignore
from codeboxapi.schema import CodeBoxStatus # type: ignore
from gui_agent_loop_core.schema.message.schema import BaseMessageContent
from langchain.callbacks.base import Callbacks

from langchain_core.callbacks import BaseCallbackHandler
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages.base import BaseMessage
Expand Down Expand Up @@ -44,18 +43,71 @@ class AgentCallbackHandler(BaseCallbackHandler):
"""Base callback handler that can be used to handle callbacks from langchain."""

def __init__(self, agent_callback_func: callable):
print("AgentCallbackHandler __init__ agent_callback_func=", agent_callback_func)
self.agent_callback_func = agent_callback_func

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain starts running."""
# print("AgentCallbackHandler on_chain_start")

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""
# print("AgentCallbackHandler on_chain_end type(outputs)=", type(outputs))
# print("AgentCallbackHandler on_chain_end type(outputs)=", outputs)
### on_chain callbacks ###
def on_chain_start(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
print("AgentCallbackHandler on_chain_start run_id=", run_id)

def on_chain_end(
self,
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running.
Args:
outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
print("AgentCallbackHandler on_chain_end run_id=", run_id, ", type(outputs)=", type(outputs))
print("AgentCallbackHandler on_chain_end self.agent_callback_func=", self.agent_callback_func)
self.agent_callback_func(outputs)

def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors.
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
print("AgentCallbackHandler on_chain_error")

### on_chat callbacks ###
def on_chat_model_start(
self,
serialized: Dict[str, Any],
Expand All @@ -67,20 +119,115 @@ def on_chat_model_start(
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
# print("AgentCallbackHandler on_chat_model_start")
"""Run when a chain starts running.
Args:
serialized (Dict[str, Any]): The serialized chain.
inputs (Dict[str, Any]): The inputs.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
kwargs (Any): Additional keyword arguments.
"""
print("AgentCallbackHandler on_chat_model_start")

### on_agent callbacks ###

def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action.
Args:
action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
print("AgentCallbackHandler on_agent_action")

def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
"""Run when chain errors."""
# print("AgentCallbackHandler on_chain_error")
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on the agent end.
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
# print("AgentCallbackHandler on_agent_action")
Args:
finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
print("AgentCallbackHandler on_agent_finish")

### on_tool callbacks ###
def on_tool_start(
self,
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when the tool starts running.
Args:
serialized (Dict[str, Any]): The serialized tool.
input_str (str): The input string.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inputs (Optional[Dict[str, Any]]): The inputs.
kwargs (Any): Additional keyword arguments.
"""
print("AgentCallbackHandler on_tool_start")

def on_tool_end(
self,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when the tool ends running.
Args:
output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
print("AgentCallbackHandler on_tool_end")

def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors.
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
# print("AgentCallbackHandler on_agent_finish")
Args:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
print("AgentCallbackHandler on_tool_error")


class CodeInterpreterSession:
Expand Down Expand Up @@ -313,12 +460,16 @@ def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse:

def _output_handler(self, response: Any) -> CodeInterpreterResponse:
"""Embed images in the response"""
print("XXXX _output_handler in response=", type(response))
final_response = self._output_handler_pre(response)
print("XXXX _output_handler step1 ")
response = self._output_handler_post(final_response)
print("XXXX _output_handler out ")
return response

async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse:
"""Embed images in the response"""
print("XXXX _aoutput_handler in response=", type(response))
final_response = self._output_handler_pre(response)
for file in self.output_files:
if str(file.name) in final_response:
Expand Down Expand Up @@ -405,7 +556,7 @@ async def agenerate_response(
traceback.print_exc()
if settings.DETAILED_ERROR:
return CodeInterpreterResponse(
content="Error in CodeInterpreterSession(agenerate_response): " f"{e.__class__.__name__} - {e}",
content=f"Error in CodeInterpreterSession(agenerate_response): {e.__class__.__name__} - {e}",
agent_name=self.brain.current_agent,
)
else:
Expand Down Expand Up @@ -479,8 +630,7 @@ async def agenerate_response_stream(
traceback.print_exc()
if settings.DETAILED_ERROR:
yield CodeInterpreterResponse(
content="Error in CodeInterpreterSession(agenerate_response_stream): "
f"{e.__class__.__name__} - {e}"
content=f"Error in CodeInterpreterSession(agenerate_response_stream): {e.__class__.__name__} - {e}"
)
else:
yield CodeInterpreterResponse(
Expand Down

0 comments on commit e6ed971

Please sign in to comment.