Skip to content

Commit

Permalink
Begin to add LangGraph agent code
Browse files Browse the repository at this point in the history
  • Loading branch information
mbklein committed Dec 4, 2024
1 parent 973faea commit 6df7611
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 4 deletions.
Empty file added chat/src/agent/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions chat/src/agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os

from typing import Literal

from agent.tools import search, aggregate
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from setup import openai_chat_client

# Tools the agent is allowed to call during a run.
tools = [search, aggregate]

# Graph node that executes whichever tool the model requested.
tool_node = ToolNode(tools)

# Chat client with the tool schemas bound so the model can emit tool calls.
model = openai_chat_client().bind_tools(tools)

def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Route from the agent node: go to "tools" when the last model message
    requested tool calls, otherwise end the run and reply to the user."""
    latest = state["messages"][-1]
    # A populated tool_calls list means the model wants a tool invocation.
    return "tools" if latest.tool_calls else END


def call_model(state: MessagesState):
    """Invoke the chat model on the conversation so far and return its reply."""
    # The deployment name is passed through so the bound Azure client
    # targets the configured model.
    reply = model.invoke(state["messages"], model=os.getenv("AZURE_DEPLOYMENT_NAME"))
    # A list is returned because it gets appended to the existing message list.
    return {"messages": [reply]}


# Define a new graph over the prebuilt message-list state schema.
workflow = StateGraph(MessagesState)

# Define the two nodes we will cycle between: the model call and tool execution.
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

# Set the entrypoint as `agent`: every run starts with a model call.
workflow.add_edge(START, "agent")

# Add a conditional edge: after `agent`, should_continue picks "tools" or END.
workflow.add_conditional_edges(
    "agent",
    should_continue,
)

# Add a normal edge from `tools` to `agent`, so tool results go back to the model.
workflow.add_edge("tools", "agent")

# Initialize memory to persist state between graph runs.
# NOTE(review): MemorySaver is in-process only — state is lost on restart.
checkpointer = MemorySaver()

# Compile the graph (debug=True logs each node transition).
app = workflow.compile(checkpointer=checkpointer, debug=True)
35 changes: 35 additions & 0 deletions chat/src/agent/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json

from langchain_core.tools import tool
from opensearch_client import opensearch_vector_store

@tool(response_format="content_and_artifact")
def search(query: str):
    """Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the context of the results with references to the document's metadata."""
    # content_and_artifact: the JSON string goes to the model, the raw
    # result objects are kept as the artifact for downstream use.
    results = opensearch_vector_store.similarity_search(query, size=20)
    return json.dumps(results, default=str), results

@tool(response_format="content_and_artifact")
def aggregate(aggregation_query: str):
    """
    Perform a quantitative aggregation on the OpenSearch index.
    Available fields:
    api_link, api_model, ark, collection.title.keyword, contributor.label.keyword, contributor.variants,
    create_date, creator.variants, date_created, embedding_model, embedding_text_length,
    folder_name, folder_number, genre.variants, id, identifier, indexed_at, language.variants,
    legacy_identifier, library_unit, location.variants, modified_date, notes.note, notes.type,
    physical_description_material, physical_description_size, preservation_level, provenance, published, publisher,
    related_url.url, related_url.label, representative_file_set.aspect_ratio, representative_file_set.url, rights_holder,
    series, status, style_period.label.keyword, style_period.variants, subject.label.keyword, subject.role,
    subject.variants, table_of_contents, technique.label.keyword, technique.variants, title.keyword, visibility, work_type
    Examples:
    - Number of collections: collection.title.keyword
    - Number of works by work type: work_type
    """
    try:
        result = opensearch_vector_store.aggregations_search(aggregation_query)
        return json.dumps(result, default=str), result
    except Exception as error:
        # Report the failure back to the model as content instead of raising,
        # so the agent loop can recover; the artifact is None on error.
        return json.dumps({"error": str(error)}), None
19 changes: 19 additions & 0 deletions chat/src/handlers/opensearch_neural_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ def similarity_search_with_score(

return documents_with_scores

def aggregations_search(self, field: str, **kwargs: Any) -> dict:
    """Perform a search with aggregations and return the aggregation results.

    Args:
        field: index field to bucket on via a terms aggregation.
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        The "aggregations" section of the OpenSearch response, or {} when
        the response has no aggregations.
    """
    # size 0: we only want aggregation buckets, not document hits.
    dsl = {
        "size": 0,
        "aggs": {"aggregation_result": {"terms": {"field": field}}},
    }

    response = self.client.search(
        index=self.index,
        body=dsl,
        # Route through the configured search pipeline when one is set;
        # NOTE(review): assumes the client accepts params=None when no
        # pipeline is configured — confirm against the opensearch-py version.
        params=(
            {"search_pipeline": self.search_pipeline}
            if self.search_pipeline
            else None
        ),
    )

    return response.get("aggregations", {})

def add_texts(self, texts: List[str], metadatas: List[dict], **kwargs: Any) -> None:
    """No-op: this store is read-only, so indexing new texts is not supported."""
    return None

Expand Down
9 changes: 5 additions & 4 deletions chat/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Runtime Dependencies
boto3~=1.34
honeybadger
honeybadger~=0.20
langchain~=0.2
langchain-aws~=0.1
langchain-openai~=0.1
langgraph~=0.2
openai~=1.35
opensearch-py
opensearch-py~=2.8
pyjwt~=2.6.0
python-dotenv~=1.0.0
requests
requests-aws4auth
requests~=2.32
requests-aws4auth~=1.3
tiktoken~=0.7
wheel~=0.40

0 comments on commit 6df7611

Please sign in to comment.