Commit
Showing 5 changed files with 118 additions and 4 deletions.
Empty file.
@@ -0,0 +1,59 @@
import os

from typing import Literal

from agent.tools import search, aggregate
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from setup import openai_chat_client

tools = [search, aggregate]

tool_node = ToolNode(tools)

model = openai_chat_client().bind_tools(tools)

# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
    messages = state["messages"]
    last_message = messages[-1]
    # If the LLM makes a tool call, then we route to the "tools" node
    if last_message.tool_calls:
        return "tools"
    # Otherwise, we stop (reply to the user)
    return END


# Define the function that calls the model
def call_model(state: MessagesState):
    messages = state["messages"]
    response = model.invoke(messages, model=os.getenv("AZURE_DEPLOYMENT_NAME"))
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}


# Define a new graph
workflow = StateGraph(MessagesState)

# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

# Set the entrypoint as `agent`
workflow.add_edge(START, "agent")

# Add a conditional edge
workflow.add_conditional_edges(
    "agent",
    should_continue,
)

# Add a normal edge from `tools` to `agent`
workflow.add_edge("tools", "agent")

# Initialize memory to persist state between graph runs
checkpointer = MemorySaver()

# Compile the graph
app = workflow.compile(checkpointer=checkpointer, debug=True)
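
Since the compiled app uses a MemorySaver checkpointer, each invocation is keyed by a thread_id so conversation state persists between runs, and the graph loops agent → tools → agent until should_continue returns END. A minimal invocation sketch, not part of the commit; the thread id and question are made up, and the Azure/OpenSearch environment is assumed to be configured:

    from langchain_core.messages import HumanMessage

    # Hypothetical call into the compiled graph above
    result = app.invoke(
        {"messages": [HumanMessage(content="How many works mention Chicago?")]},
        config={"configurable": {"thread_id": "example-thread"}},
    )
    print(result["messages"][-1].content)  # final model reply after any tool calls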
@@ -0,0 +1,35 @@
import json

from langchain_core.tools import tool
from opensearch_client import opensearch_vector_store

@tool(response_format="content_and_artifact")
def search(query: str):
    """Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the context of the results with references to the document's metadata."""
    query_results = opensearch_vector_store.similarity_search(query, size=20)
    return json.dumps(query_results, default=str), query_results

@tool(response_format="content_and_artifact")
def aggregate(aggregation_query: str):
    """
    Perform a quantitative aggregation on the OpenSearch index.
    Available fields:
    api_link, api_model, ark, collection.title.keyword, contributor.label.keyword, contributor.variants,
    create_date, creator.variants, date_created, embedding_model, embedding_text_length,
    folder_name, folder_number, genre.variants, id, identifier, indexed_at, language.variants,
    legacy_identifier, library_unit, location.variants, modified_date, notes.note, notes.type,
    physical_description_material, physical_description_size, preservation_level, provenance, published, publisher,
    related_url.url, related_url.label, representative_file_set.aspect_ratio, representative_file_set.url, rights_holder,
    series, status, style_period.label.keyword, style_period.variants, subject.label.keyword, subject.role,
    subject.variants, table_of_contents, technique.label.keyword, technique.variants, title.keyword, visibility, work_type
    Examples:
    - Number of collections: collection.title.keyword
    - Number of works by work type: work_type
    """
    try:
        response = opensearch_vector_store.aggregations_search(aggregation_query)
        return json.dumps(response, default=str), response
    except Exception as e:
        return json.dumps({"error": str(e)}), None
@@ -1,14 +1,15 @@
 # Runtime Dependencies
 boto3~=1.34
-honeybadger
+honeybadger~=0.20
 langchain~=0.2
 langchain-aws~=0.1
 langchain-openai~=0.1
 langgraph~=0.2
 openai~=1.35
-opensearch-py
+opensearch-py~=2.8
 pyjwt~=2.6.0
 python-dotenv~=1.0.0
-requests
-requests-aws4auth
+requests~=2.32
+requests-aws4auth~=1.3
 tiktoken~=0.7
 wheel~=0.40
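
The ~= constraints in this file are PEP 440 compatible-release specifiers: requests~=2.32 allows any 2.x release at or above 2.32, while pyjwt~=2.6.0 stays within the 2.6.x series. A quick way to sanity-check what a pin admits, shown here with the packaging library purely as an illustration:

    from packaging.specifiers import SpecifierSet

    print(SpecifierSet("~=2.32").contains("2.40"))    # True  (>=2.32, <3.0)
    print(SpecifierSet("~=2.32").contains("3.0"))     # False
    print(SpecifierSet("~=2.6.0").contains("2.6.9"))  # True  (>=2.6.0, <2.7.0)
    print(SpecifierSet("~=2.6.0").contains("2.7.0"))  # False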