-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_api.py
229 lines (189 loc) · 8.35 KB
/
test_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import asyncio
import json
import os
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional

import geopandas as gpd
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, START, StateGraph
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict

load_dotenv()
# Connection Manager for WebSocket clients
class ConnectionManager:
    """Registry of open WebSocket connections plus per-client saved state, keyed by client id."""

    def __init__(self):
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> arbitrary state the endpoint stashes between messages
        self.thread_states: Dict[str, Any] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake, then register *websocket* under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget the connection and any saved state for *client_id*; unknown ids are ignored."""
        self.active_connections.pop(client_id, None)
        self.thread_states.pop(client_id, None)

    async def send_message(self, client_id: str, message: dict):
        """Send *message* as JSON to *client_id*; do nothing if that client is not registered."""
        try:
            websocket = self.active_connections[client_id]
        except KeyError:
            return
        await websocket.send_json(message)

    def set_thread_state(self, client_id: str, state: Any):
        """Store *state* for *client_id*, replacing whatever was there."""
        self.thread_states[client_id] = state

    def get_thread_state(self, client_id: str) -> Optional[Any]:
        """Return the state stored for *client_id*, or ``None`` when nothing is stored."""
        return self.thread_states.get(client_id, None)
class MessageType(str, Enum):
    """Values accepted in the ``type`` field of a client frame."""

    QUERY = "query"  # start a new agent turn; the endpoint reads content["query"]
    HUMAN_INPUT = "human_input"  # answer a pending choice; the endpoint reads content["selected_index"]
    RESULT = "result"
    ERROR = "error"


@dataclass
class WebSocketMessage:
    """One decoded client frame, built as ``WebSocketMessage(**json.loads(raw))``.

    ``type`` is converted to :class:`MessageType` on construction, so an
    unknown value raises ``ValueError`` instead of slipping through as a
    plain string that matches no handler.
    """

    type: MessageType
    content: dict
    client_id: str

    def __post_init__(self):
        # dataclasses do not convert field values; JSON delivers a plain str.
        # MessageType(...) accepts both a str value and an existing member.
        self.type = MessageType(self.type)
app = FastAPI()


@app.get("/")
def serve_root():
    """Serve the frontend page ``frontend/index.html`` (path relative to the working directory)."""
    # Decode explicitly as UTF-8: without `encoding`, open() uses the
    # platform's locale encoding, and HTMLResponse labels the body as UTF-8.
    with open("frontend/index.html", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content, status_code=200)
manager = ConnectionManager()

# Chat model driving the agent. The Ollama model can be chosen with the
# OLLAMA_CHAT_MODEL environment variable (a .env file also works, since
# load_dotenv() runs at import time); it defaults to "qwen2.5:7b".
# Alternative backend: llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
llm = ChatOllama(model=os.environ.get("OLLAMA_CHAT_MODEL", "qwen2.5:7b"))
@tool
def location(query: str):
    "Returns location of a place"
    # The docstring above doubles as the tool description given to the model,
    # so it is kept short and unchanged.
    #
    # `query` is chosen by the LLM from the user's text and is spliced into an
    # SQL filter string. Double any single quotes (standard SQL string-literal
    # escaping) so names like "Cox's Bazar" neither break the filter nor
    # inject extra SQL.
    # NOTE(review): '%' and '_' in the query still act as LIKE wildcards.
    escaped_query = query.replace("'", "''")
    match_df = gpd.read_file(
        "data/gadm_410_small.gpkg",
        where=f"name like '%{escaped_query}%'",
    )
    # GeoJSON FeatureCollection string of every matching row.
    return match_df.to_json()
@tool
def weather(query: str):
    "Retuns weather of a place"
    # Stub tool: reports the same conditions for any place. The docstring is
    # the tool description sent to the model, so it is left verbatim.
    template = "The weather of {} is hot & humid."
    return template.format(query)
# Every tool the agent may call. The same list is bound to the model (so it
# can request them) and handed to ToolNode in the workflow (so they execute).
tools = [
    location,
    weather,
]
llm = llm.bind_tools(tools)
def should_continue(state):
    """Route after the agent node.

    Returns "continue" when the newest message requests tool calls, otherwise
    "end"; these are the keys of the conditional-edge map in the workflow.
    """
    pending_calls = state["messages"][-1].tool_calls
    return "continue" if pending_calls else "end"
def call_model(state):
    """Agent node: send the conversation so far to the tool-bound LLM.

    Returns the model's reply as a ``messages`` state update.
    """
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}
# Setup LangGraph workflow
# Shape: START -> agent -> (action -> agent)* -> END
tool_node = ToolNode(tools)
wf = StateGraph(MessagesState)

wf.add_node("agent", call_model)
wf.add_node("action", tool_node)

wf.add_edge(START, "agent")
# should_continue returns one of these keys after every agent step.
agent_routes = {"end": END, "continue": "action"}
wf.add_conditional_edges("agent", should_continue, agent_routes)
wf.add_edge("action", "agent")

# MemorySaver keeps per-thread checkpoints in process memory. interrupt_after
# stops each run right after the "action" (tool) node, so the caller can
# inspect or rewrite tool output before the agent continues.
memory = MemorySaver()
graph = wf.compile(checkpointer=memory, interrupt_after=["action"])
# interrupt_after=["action"] pauses the graph after *every* tool step, but only
# the location tool needs a human decision. Other tool results are resumed
# automatically, at most this many times per user turn, so a model that keeps
# requesting tools cannot loop forever.
_MAX_AUTO_RESUMES = 10


def _read_geojson(text: str) -> gpd.GeoDataFrame:
    """Parse a GeoJSON document held in a string (the ``location`` tool's output)."""
    return gpd.read_file(text, driver="GeoJSON")


async def _send_error(client_id: str, content: str) -> None:
    """Send an ``error`` frame to *client_id*."""
    await manager.send_message(client_id, {"type": "error", "content": content})


async def _stream_graph(
    client_id: str,
    graph_input: Optional[dict],
    thread: dict,
    resolved_tool_call_id: Optional[str] = None,
) -> None:
    """Run (``graph_input`` given) or resume (``None``) the graph for one user turn.

    The newest message of each streamed state is forwarded as an ``update``
    frame. When a new ``location`` tool result arrives, its candidate places
    are sent as an ``options`` frame, the tool message is stored on the
    connection manager, and this returns without resuming so a later
    ``human_input`` frame can pick one.

    ``resolved_tool_call_id`` identifies a location result the user already
    answered; it is forwarded as a plain update rather than prompting again
    (the graph may re-emit the current state when a run is resumed).
    """
    skip_message_id = None
    for _ in range(_MAX_AUTO_RESUMES + 1):
        awaiting_human = False
        last_forwarded_id = None
        async for event in graph.astream(graph_input, thread, stream_mode="values"):
            if awaiting_human:
                # Drain instead of breaking so the run finishes cleanly; the
                # graph pauses right after the tool step anyway.
                continue
            # NOTE(review): only the newest message is inspected, so with
            # parallel tool calls a location result that is not last is missed.
            last_message = event["messages"][-1]
            if skip_message_id is not None and last_message.id == skip_message_id:
                continue  # state re-emitted on automatic resume; already sent
            if (
                isinstance(last_message, ToolMessage)
                and last_message.name == "location"
                and last_message.tool_call_id != resolved_tool_call_id
            ):
                # gpd.read_file blocks; keep it off the event loop.
                options = await asyncio.to_thread(_read_geojson, last_message.content)
                manager.set_thread_state(client_id, {
                    "thread": thread,
                    "tool_message": last_message,
                })
                locations = [
                    {"id": idx, "name": row["name"]}
                    for idx, row in options.iterrows()
                ]
                await manager.send_message(client_id, {
                    "type": "options",
                    "tool": "location",
                    "options": locations,
                })
                awaiting_human = True
                continue
            await manager.send_message(client_id, {
                "type": "update",
                "content": last_message.content,
            })
            last_forwarded_id = last_message.id
        if awaiting_human or not graph.get_state(thread).next:
            return  # waiting for the user, or the turn reached END
        # Paused after a tool that needs no human input: resume automatically.
        graph_input = None
        skip_message_id = last_forwarded_id
    await _send_error(client_id, "Stopped after too many consecutive tool calls")


async def _handle_query(client_id: str, content: dict) -> None:
    """Start a new agent turn for ``content["query"]`` on the client's thread."""
    inputs = [HumanMessage(content=content["query"])]
    thread = {"configurable": {"thread_id": client_id}}
    # A new query supersedes any location choice the client never answered.
    manager.set_thread_state(client_id, None)
    await _stream_graph(client_id, {"messages": inputs}, thread)


async def _handle_human_input(client_id: str, content: dict) -> None:
    """Apply the user's pick among pending location options, then resume the graph."""
    thread_state = manager.get_thread_state(client_id)
    if not thread_state:
        await _send_error(client_id, "No active state found")
        return
    tool_message = thread_state["tool_message"]
    thread = thread_state["thread"]

    selected_idx = content["selected_index"]
    options = await asyncio.to_thread(_read_geojson, tool_message.content)
    # Validate explicitly: iloc would silently accept -1 (last row) or a bool.
    if (
        isinstance(selected_idx, bool)
        or not isinstance(selected_idx, int)
        or not 0 <= selected_idx < len(options)
    ):
        await _send_error(client_id, f"Invalid selected_index: {selected_idx!r}")
        return
    selected_row = options.iloc[selected_idx]
    name = selected_row["name"]

    # Replace the candidate list with the chosen place so the agent only sees
    # that one, described by the row's own (non-geometry) attributes.
    attributes = selected_row.drop(labels=options.geometry.name).to_dict()
    tool_message.content = (
        f"The user selected {name}. "
        f"Attributes of the selected place: {json.dumps(attributes, default=str)}"
    )
    graph.update_state(thread, {"messages": tool_message})
    # The choice is consumed; a repeated human_input must not apply it again.
    manager.set_thread_state(client_id, None)

    await manager.send_message(client_id, {
        "type": "map_update",
        "geojson": {
            "type": "Feature",
            "properties": {"name": name},
            "geometry": selected_row.geometry.__geo_interface__,
        },
    })
    await _stream_graph(
        client_id, None, thread, resolved_tool_call_id=tool_message.tool_call_id
    )


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Serve one client's WebSocket session.

    Incoming frames are JSON objects with ``type``, ``content`` and
    ``client_id``. ``query`` (reads ``content["query"]``) starts a new agent
    turn. ``human_input`` (reads ``content["selected_index"]``) answers a
    pending location choice. Outgoing frames are ``update``, ``options``,
    ``map_update`` and ``error``. The client is unregistered however the
    session ends.
    """
    await manager.connect(websocket, client_id)
    try:
        while True:
            raw_data = await websocket.receive_text()
            try:
                message = WebSocketMessage(**json.loads(raw_data))
                if message.type == MessageType.QUERY:
                    await _handle_query(client_id, message.content)
                elif message.type == MessageType.HUMAN_INPUT:
                    await _handle_human_input(client_id, message.content)
                else:
                    await _send_error(client_id, "Unsupported message type")
            except json.JSONDecodeError:
                await _send_error(client_id, "Invalid JSON message")
            except WebSocketDisconnect:
                raise  # the socket is gone; don't try to send an error frame on it
            except Exception as e:
                # Report any other per-message failure and keep serving.
                await _send_error(client_id, str(e))
    except WebSocketDisconnect:
        pass  # normal client disconnect
    finally:
        # Runs on every exit path, not only a clean disconnect.
        manager.disconnect(client_id)