diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 98b258a..c3dd274 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -1,19 +1,18 @@ import re import traceback from types import TracebackType -from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Type, Union +from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Type from uuid import UUID from codeboxapi import CodeBox # type: ignore from codeboxapi.schema import CodeBoxStatus # type: ignore from gui_agent_loop_core.schema.message.schema import BaseMessageContent from langchain.callbacks.base import Callbacks - -from langchain_core.callbacks import BaseCallbackHandler from langchain_community.chat_message_histories.in_memory import ChatMessageHistory from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory from langchain_community.chat_message_histories.redis import RedisChatMessageHistory from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.language_models import BaseLanguageModel from langchain_core.messages.base import BaseMessage @@ -44,18 +43,71 @@ class AgentCallbackHandler(BaseCallbackHandler): """Base callback handler that can be used to handle callbacks from langchain.""" def __init__(self, agent_callback_func: callable): + print("AgentCallbackHandler __init__ agent_callback_func=", agent_callback_func) self.agent_callback_func = agent_callback_func - def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any: - """Run when chain starts running.""" - # print("AgentCallbackHandler on_chain_start") - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: - """Run when chain ends running.""" - # print("AgentCallbackHandler on_chain_end type(outputs)=", type(outputs)) - # print("AgentCallbackHandler on_chain_end type(outputs)=", outputs) + ### on_chain callbacks ### + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ + print("AgentCallbackHandler on_chain_start run_id=", run_id) + + def on_chain_end( + self, + outputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain ends running. + + Args: + outputs (Dict[str, Any]): The outputs of the chain. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" + print("AgentCallbackHandler on_chain_end run_id=", run_id, ", type(outputs)=", type(outputs)) + print("AgentCallbackHandler on_chain_end self.agent_callback_func=", self.agent_callback_func) self.agent_callback_func(outputs) + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain errors. + + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" + print("AgentCallbackHandler on_chain_error") + + ### on_chat callbacks ### def on_chat_model_start( self, serialized: Dict[str, Any], @@ -67,20 +119,115 @@ def on_chat_model_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - """Run when chain starts running.""" - # print("AgentCallbackHandler on_chat_model_start") + """Run when a chain starts running. + + Args: + serialized (Dict[str, Any]): The serialized chain. + inputs (Dict[str, Any]): The inputs. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + kwargs (Any): Additional keyword arguments. + """ + print("AgentCallbackHandler on_chat_model_start") + + ### on_agent callbacks ### + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action. + + Args: + action (AgentAction): The agent action. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" + print("AgentCallbackHandler on_agent_action") - def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: - """Run when chain errors.""" - # print("AgentCallbackHandler on_chain_error") + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on the agent end. - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - # print("AgentCallbackHandler on_agent_action") + Args: + finish (AgentFinish): The agent finish. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" + print("AgentCallbackHandler on_agent_finish") + + ### on_tool callbacks ### + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + """Run when the tool starts running. + + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input string. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + tags (Optional[List[str]]): The tags. + metadata (Optional[Dict[str, Any]]): The metadata. + inputs (Optional[Dict[str, Any]]): The inputs. + kwargs (Any): Additional keyword arguments. + """ + print("AgentCallbackHandler on_tool_start") + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when the tool ends running. + + Args: + output (Any): The output of the tool. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" + print("AgentCallbackHandler on_tool_end") + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool errors. - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run on agent end.""" - # print("AgentCallbackHandler on_agent_finish") + Args: + error (BaseException): The error that occurred. + run_id (UUID): The run ID. This is the ID of the current run. + parent_run_id (UUID): The parent run ID. This is the ID of the parent run. + kwargs (Any): Additional keyword arguments.""" + print("AgentCallbackHandler on_tool_error") class CodeInterpreterSession: @@ -313,12 +460,16 @@ def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse: def _output_handler(self, response: Any) -> CodeInterpreterResponse: """Embed images in the response""" + print("XXXX _output_handler in response=", type(response)) final_response = self._output_handler_pre(response) + print("XXXX _output_handler step1 ") response = self._output_handler_post(final_response) + print("XXXX _output_handler out ") return response async def _aoutput_handler(self, response: str) -> CodeInterpreterResponse: """Embed images in the response""" + print("XXXX _aoutput_handler in response=", type(response)) final_response = self._output_handler_pre(response) for file in self.output_files: if str(file.name) in final_response: @@ -405,7 +556,7 @@ async def agenerate_response( traceback.print_exc() if settings.DETAILED_ERROR: return CodeInterpreterResponse( - content="Error in CodeInterpreterSession(agenerate_response): " f"{e.__class__.__name__} - {e}", + content=f"Error in CodeInterpreterSession(agenerate_response): {e.__class__.__name__} - {e}", agent_name=self.brain.current_agent, ) else: @@ -479,8 +630,7 @@ async def agenerate_response_stream( traceback.print_exc() if settings.DETAILED_ERROR: yield CodeInterpreterResponse( - content="Error in CodeInterpreterSession(agenerate_response_stream): " - f"{e.__class__.__name__} - {e}" + content=f"Error in CodeInterpreterSession(agenerate_response_stream): {e.__class__.__name__} - {e}" ) else: yield CodeInterpreterResponse(