Send tool start/end, tokens, and stop/result messages over websocket
mbklein committed Dec 10, 2024
1 parent 65b2c13 commit d377906
Showing 4 changed files with 50 additions and 26 deletions.
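Before the per-file diffs, a note on the protocol this commit establishes: every message sent over the websocket shares the same envelope, a `type` discriminator plus the request's `ref`, so a client can correlate replies with the question that produced them. The sketch below shows one way a client might dispatch those frames; the `handle_message` function and its print actions are illustrative assumptions, not code from this repository:

```python
import json

def handle_message(raw: str, expected_ref: str) -> None:
    """Dispatch one websocket frame from the chat backend (illustrative sketch)."""
    msg = json.loads(raw)
    if msg.get("ref") != expected_ref:
        return  # reply belongs to a different request
    match msg["type"]:
        case "token":
            print(msg["message"], end="", flush=True)  # one streamed LLM token
        case "stop":
            print()                                    # token stream is finished
        case "answer":
            print(f"\nFinal answer: {msg['message']}")
        case "tool_start":
            info = msg["message"]                      # {"tool": ..., "input": ...}
            print(f"\nRunning tool {info['tool']} with input {info['input']}")
        case "aggregation_result" | "source_documents":
            print(f"\n{msg['type']}: {msg['message']}")

# Example frame shaped like the ones AgentHandler emits:
handle_message('{"type": "token", "ref": "abc123", "message": "Hello"}', "abc123")
```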
59 changes: 46 additions & 13 deletions chat/src/agent/agent_handler.py
@@ -1,22 +1,55 @@
-from typing import Any, Dict, Optional, Union, List
-from uuid import UUID
+from typing import Any, Dict, List
 
 from websocket import Websocket
 
+from json.decoder import JSONDecodeError
 from langchain_core.callbacks import BaseCallbackHandler
-from langchain_core.messages import BaseMessage
-from langchain.schema import AgentFinish, AgentAction
+from langchain_core.messages.tool import ToolMessage
+from langchain_core.outputs import LLMResult
+
+import ast
+import json
 
-class AgentHandler(BaseCallbackHandler):
-    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-        socket: Websocket = metadata.get("socket")
+def deserialize_input(input_str):
+    try:
+        return ast.literal_eval(input_str)
+    except (ValueError, SyntaxError):
+        try:
+            return json.loads(input_str)
+        except JSONDecodeError:
+            return input_str
+
+class AgentHandler(BaseCallbackHandler):
+    def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict[str, Any]):
         if socket is None:
-            raise ValueError("Socket not defined in agent handler via metadata")
+            raise ValueError("Socket not provided to agent callback handler")
+        self.socket = socket
+        self.ref = ref
+        super().__init__(*args, **kwargs)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
+        content = response.generations[0][0].text
+        if content != "":
+            self.socket.send({"type": "stop", "ref": self.ref})
+            self.socket.send({"type": "answer", "ref": self.ref, "message": content})
 
-        socket.send({"type": "tool_start", "message": serialized})
+    def on_llm_new_token(self, token: str, **kwargs: Dict[str, Any]):
+        if token != "":
+            self.socket.send({"type": "token", "ref": self.ref, "message": token})
 
+    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Dict[str, Any]) -> Any:
+        input = deserialize_input(input_str)
+        self.socket.send({"type": "tool_start", "ref": self.ref, "message": {"tool": serialized.get("name"), "input": input}})
 
-    def on_tool_end(self, output, **kwargs):
-        print("on_tool_end output:")
-        print(output)
+    def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
+        match output.name:
+            case "aggregate":
+                self.socket.send({"type": "aggregation_result", "ref": self.ref, "message": output.artifact.get("aggregation_result", {})})
+            case "search":
+                try:
+                    docs: List[Dict[str, Any]] = [doc.metadata for doc in output.artifact]
+                    self.socket.send({"type": "source_documents", "ref": self.ref, "message": docs})
+                except json.decoder.JSONDecodeError as e:
+                    print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
+            case _:
+                print(f"Unhandled tool_end message: {output}")
11 changes: 2 additions & 9 deletions chat/src/agent/search_agent.py
@@ -3,20 +3,17 @@
 from typing import Literal
 
 from agent.tools import search, aggregate
-from langchain_core.runnables.config import RunnableConfig
 from langchain_core.messages.base import BaseMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
 from setup import openai_chat_client
-from websocket import Websocket
 
 
 tools = [search, aggregate]
 
 tool_node = ToolNode(tools)
 
-model = openai_chat_client().bind_tools(tools)
+model = openai_chat_client(streaming=True).bind_tools(tools)
 
 # Define the function that determines whether to continue or not
 def should_continue(state: MessagesState) -> Literal["tools", END]:
@@ -30,15 +27,11 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
 
 
 # Define the function that calls the model
-def call_model(state: MessagesState, config: RunnableConfig):
+def call_model(state: MessagesState):
     messages = state["messages"]
-    socket = config["configurable"].get("socket", None)
     response: BaseMessage = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
     # We return a list, because this will get added to the existing list
-    # if socket is not none and the response content is not an empty string
-    if socket is not None and response.content != "":
-        print("Sending response to socket")
-        socket.send({"type": "answer", "message": response.content})
     return {"messages": [response]}
 
 
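Worth noting for search_agent.py: the `on_llm_new_token` callback in AgentHandler only fires because the chat client is now built with `streaming=True`; without streaming, LangChain delivers the completion in one piece and emits no per-token events. A minimal sketch of that relationship, assuming AzureChatOpenAI underneath (this diff does not show `openai_chat_client`'s internals):

```python
from langchain_core.callbacks import BaseCallbackHandler
from langchain_openai import AzureChatOpenAI  # assumed; the repo wraps this in setup.openai_chat_client

class TokenPrinter(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Only invoked when the model was constructed with streaming=True.
        print(token, end="", flush=True)

# Deployment name is a placeholder; Azure endpoint/key/version come from the environment.
model = AzureChatOpenAI(azure_deployment="my-llm-deployment", streaming=True)
model.invoke("Say hello", config={"callbacks": [TokenPrinter()]})
```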
5 changes: 2 additions & 3 deletions chat/src/handlers/chat.py
@@ -1,7 +1,6 @@
 import secrets # noqa
 import boto3
 import json
-import logging
 import os
 from datetime import datetime
 from event_config import EventConfig
@@ -71,11 +70,11 @@ def handler(event, context):
         logGroupName=log_group, logStreamName=log_stream, logEvents=log_events
     )
 
-    callbacks = [AgentHandler()]
+    callbacks = [AgentHandler(config.socket, config.ref)]
     try:
         search_agent.invoke(
             {"messages": [HumanMessage(content=config.question)]},
-            config={"configurable": {"thread_id": config.ref, "socket": config.socket}, "callbacks": callbacks, "metadata": {"socket": config.socket}},
+            config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks},
             debug=False
         )
     except Exception as e:
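The simplified invoke config in chat.py works because LangChain propagates a RunnableConfig's `callbacks` list to every runnable a graph executes, including the model and tool nodes, so the socket no longer has to ride along in `configurable` or `metadata`. A toy demonstration of that propagation (the one-node graph and `EchoHandler` are illustrative, not from this repo):

```python
from langchain_core.callbacks import BaseCallbackHandler
from langgraph.graph import END, START, StateGraph, MessagesState

class EchoHandler(BaseCallbackHandler):
    def on_chain_start(self, serialized, inputs, **kwargs):
        print("node starting")  # fires inside the graph because callbacks propagate

def node(state: MessagesState):
    return {"messages": []}

graph = StateGraph(MessagesState)
graph.add_node("node", node)
graph.add_edge(START, "node")
graph.add_edge("node", END)
app = graph.compile()
app.invoke({"messages": []}, config={"callbacks": [EchoHandler()]})
```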
1 change: 0 additions & 1 deletion chat/src/websocket.py
@@ -16,7 +16,6 @@ def send(self, data):
         if self.connection_id == "debug":
             print(data)
         else:
-            print(f"Sending data to {self.connection_id}: {data}")
             self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id)
         return data
