Skip to content

Commit

Permalink
fix: use generate_response_stream for agent response
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Aug 7, 2024
1 parent 5ad4283 commit 94ebe19
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 22 deletions.
3 changes: 2 additions & 1 deletion src/codeinterpreterapi/agents/structured_chat/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.utils.runnable import create_complement_input
from codeinterpreterapi.utils.runnable_history import assign_runnable_history


def create_structured_chat_agent_wrapper(
Expand Down Expand Up @@ -171,7 +172,7 @@ def create_structured_chat_agent(
| llm_with_stop
| output_parser
)
# agent = assign_runnable_history(agent, runnable_config)
agent = assign_runnable_history(agent, runnable_config)
return agent


Expand Down
3 changes: 2 additions & 1 deletion src/codeinterpreterapi/agents/tool_calling/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from codeinterpreterapi.brain.params import CodeInterpreterParams
from codeinterpreterapi.llm.llm import prepare_test_llm
from codeinterpreterapi.utils.runnable import create_complement_input
from codeinterpreterapi.utils.runnable_history import assign_runnable_history


def create_tool_calling_agent_wrapper(
Expand Down Expand Up @@ -106,7 +107,7 @@ def magic_function(input: int) -> int:
| llm_with_tools
| output_parser
)
# agent = assign_runnable_history(agent, runnable_config)
agent = assign_runnable_history(agent, runnable_config)
return agent


Expand Down
2 changes: 2 additions & 0 deletions src/codeinterpreterapi/brain/params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
from uuid import UUID

from codeboxapi import CodeBox # type: ignore
from gui_agent_loop_core.schema.agent.schema import AgentDefinition
Expand Down Expand Up @@ -27,6 +28,7 @@ def test_multiply(first: str, second: str) -> str:

class CodeInterpreterParams(BaseModel):
codebox: Optional[CodeBox] = None
session_id: Optional[UUID] = None
llm_lite: Optional[Runnable] = None
llm_fast: Optional[Runnable] = None
llm_smart: Optional[Runnable] = None
Expand Down
96 changes: 76 additions & 20 deletions src/codeinterpreterapi/session.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import re
import traceback
from types import TracebackType
from typing import Any, Optional, Type
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from uuid import UUID

from codeboxapi import CodeBox # type: ignore
from gui_agent_loop_core.schema.message.schema import GuiAgentInterpreterChatResponseStr
from langchain.callbacks.base import Callbacks
from langchain.callbacks.base import BaseCallbackHandler, Callbacks
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain_community.chat_message_histories.redis import RedisChatMessageHistory
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool

Expand All @@ -33,6 +34,48 @@ def _handle_deprecated_kwargs(kwargs: dict) -> None:
settings.MAX_ITERATIONS = kwargs.get("max_iterations", settings.MAX_ITERATIONS)


class AgentCallbackHandler(BaseCallbackHandler):
"""Base callback handler that can be used to handle callbacks from langchain."""

def __init__(self, agent_callback_func: callable):
self.agent_callback_func = agent_callback_func

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain starts running."""
print("AgentCallbackHandler on_chain_start")

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
"""Run when chain ends running."""
print("AgentCallbackHandler on_chain_end")
self.agent_callback_func(outputs)

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
print("AgentCallbackHandler on_chat_model_start")

def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
"""Run when chain errors."""
print("AgentCallbackHandler on_chain_error")

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
print("AgentCallbackHandler on_agent_action")

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
print("AgentCallbackHandler on_agent_finish")


class CodeInterpreterSession:
def __init__(
self,
Expand All @@ -54,8 +97,11 @@ def __init__(
llm_tools: Runnable = CodeInterpreterLlm.get_llm_switcher_tools()

# runnable_config
configurable = {"session_id": "123"} # TODO: set session_id
runnable_config = RunnableConfig(configurable=configurable)
init_session_id = "12345678-1234-1234-1234-123456789abc"
configurable = {"session_id": init_session_id} # TODO: set session_id
runnable_config = RunnableConfig(
configurable=configurable, callbacks=[AgentCallbackHandler(self._output_handler)]
)

# ci_params = {}
self.ci_params = CodeInterpreterParams(
Expand All @@ -72,6 +118,7 @@ def __init__(
is_ja=is_ja,
runnable_config=runnable_config,
)
self.ci_params.session_id = UUID(init_session_id)
self.brain = CodeInterpreterBrain(self.ci_params)
self.log("llm=" + str(llm))

Expand All @@ -83,11 +130,12 @@ def __init__(
def from_id(cls, session_id: UUID, **kwargs: Any) -> "CodeInterpreterSession":
session = cls(**kwargs)
session.ci_params.codebox = CodeBox.from_id(session_id)
session.ci_params.session_id = session_id
return session

@property
def session_id(self) -> Optional[UUID]:
return self.ci_params.codebox.session_id
return self.ci_params.session_id

def start(self) -> SessionStatus:
print("start")
Expand Down Expand Up @@ -181,7 +229,7 @@ def _output_handler_pre(self, response: Any) -> str:
print("generate_response brain.invoke output_str=", output_str)
return output_str

def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
def _output_handler_post(self, final_response: str) -> CodeInterpreterResponse:
"""Embed images in the response"""
for file in self.output_files:
if str(file.name) in final_response:
Expand All @@ -206,6 +254,12 @@ def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
)
return response

def _output_handler(self, response: Any) -> CodeInterpreterResponse:
"""Embed images in the response"""
final_response = self._output_handler_pre(response)
response = self._output_handler_post(final_response)
return response

async def _aoutput_handler(self, final_response: str) -> CodeInterpreterResponse:
"""Embed images in the response"""
for file in self.output_files:
Expand Down Expand Up @@ -254,8 +308,7 @@ def generate_response(
# ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
response = self.brain.invoke(input=input_message)
# ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
output_str = self._output_handler_pre(response)
return self._output_handler(output_str)
return self._output_handler(response)
except Exception as e:
traceback_str = "\n"
if self.verbose:
Expand Down Expand Up @@ -308,36 +361,39 @@ def generate_response_stream(
self,
user_msg: str,
files: list[File] = None,
) -> GuiAgentInterpreterChatResponseStr:
) -> Iterator[str]:
"""Generate a Code Interpreter response based on the user's input."""
if files is None:
files = []
user_request = UserRequest(content=user_msg, files=files)
try:
self._input_handler(user_request)
agent_scratchpad = ""
input_message = {"input": user_request.content, "agent_scratchpad": agent_scratchpad}
print("llm stream start")
# ======= ↓↓↓↓ LLM invoke ↓↓↓↓ #=======
response = self.brain.stream(input=user_request.content)
response_stream = self.brain.stream(input=input_message)
# ======= ↑↑↑↑ LLM invoke ↑↑↑↑ #=======
print("llm stream response(type)=", type(response))
print("llm stream response=", response)
print("llm stream response(type)=", type(response_stream))

full_output = ""
for chunk in response:
yield chunk
full_output += chunk["output"]
for chunk in response_stream:
if isinstance(chunk, dict) and "output" in chunk:
output = chunk["output"]
else:
output = str(chunk)
yield output
full_output += output

print("generate_response_stream brain.stream full_output=", full_output)
self._aoutput_handler(full_output)
except Exception as e:
if self.verbose:
traceback.print_exc()
if settings.DETAILED_ERROR:
yield "Error in CodeInterpreterSession(generate_response_stream): " f"{e.__class__.__name__} - {e}"
yield f"Error in CodeInterpreterSession(generate_response_stream): {e.__class__.__name__} - {e}"
else:
yield (
"Sorry, something went while generating your response." + "Please try again or restart the session."
)
yield "Sorry, something went wrong while generating your response. Please try again or restart the session."

async def agenerate_response_stream(
self,
Expand Down

0 comments on commit 94ebe19

Please sign in to comment.