Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

efficiency improved + error handling #4

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# GAMER: Generative Analysis for Metadata Retrieval
# GAMER: Generative Analysis for Metadata Retrieval

[![License](https://img.shields.io/badge/license-MIT-brightgreen)](LICENSE)
![Code Style](https://img.shields.io/badge/code%20style-black-black)
Expand Down Expand Up @@ -29,6 +29,8 @@ Install the chatbot package -- ensure virtual environment is running.
pip install metadata-chatbot
```

Create a folder called huggingface_cache in the directory in which you are running the model.

## Usage

To call the model,
Expand Down
39 changes: 39 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Import the Streamlit library
import streamlit as st
from metadata_chatbot.agents.GAMER import GAMER
import asyncio

# Run from a terminal with: streamlit run <FILE PATH> [ARGUMENTS]


async def main():
    """Render a Streamlit chat UI that answers questions via the GAMER model.

    Streamlit re-executes this script on every user interaction, so the
    conversation is persisted in ``st.session_state.messages`` (a list of
    ``{"role": ..., "content": ...}`` dicts) and replayed on each rerun.
    """
    # NOTE(review): GAMER() is re-built on every Streamlit rerun; if
    # construction is expensive, consider caching it in st.session_state —
    # TODO confirm a GAMER instance is safe to reuse across reruns.
    llm = GAMER()

    # Static assistant greeting shown at the top of every rerun.
    greeting = st.chat_message("assistant")
    greeting.write("Hello!")

    prompt = st.chat_input("Ask a question about the AIND Metadata!")

    # Initialize the persistent chat history on the first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay prior turns so the conversation survives reruns.
    # (Loop variable renamed from `message` so it no longer shadows the
    # assistant-greeting container created above.)
    for past_turn in st.session_state.messages:
        with st.chat_message(past_turn["role"]):
            st.markdown(past_turn["content"])

    if prompt:
        # Display user message in chat message container.
        with st.chat_message("user"):
            st.markdown(prompt)
        # Add user message to chat history before querying the model.
        st.session_state.messages.append({"role": "user", "content": prompt})
        response = await llm.ainvoke(prompt)

        with st.chat_message("assistant"):
            st.markdown(response)

        # Add assistant response to chat history.
        st.session_state.messages.append(
            {"role": "assistant", "content": response}
        )


if __name__ == "__main__":
    asyncio.run(main())
54 changes: 6 additions & 48 deletions src/metadata_chatbot/agents/agentic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@
from langgraph.prebuilt import create_react_agent

MODEL_ID_SONNET_3 = "anthropic.claude-3-sonnet-20240229-v1:0"
MODEL_ID_SONNET_3_5 = "anthropic.claude-3-sonnet-20240229-v1:0"
MODEL_ID_SONNET_3_5 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
SONNET_3_LLM = ChatBedrock(
model_id= MODEL_ID_SONNET_3,
model_kwargs= {
"temperature": 0
}
},
streaming = True
)

SONNET_3_5_LLM = ChatBedrock(
model_id= MODEL_ID_SONNET_3_5,
model_kwargs= {
"temperature": 0
}
},
streaming = True
)

# Determining if entire database needs to be surveyed
Expand All @@ -35,50 +37,6 @@ class RouteQuery(TypedDict):
router_prompt = hub.pull("eden19/query_rerouter")
datasource_router = router_prompt | structured_llm_router

# Tool to survey entire database
API_GATEWAY_HOST = "api.allenneuraldynamics.org"
DATABASE = "metadata_index"
COLLECTION = "data_assets"

docdb_api_client = MetadataDbClient(
host=API_GATEWAY_HOST,
database=DATABASE,
collection=COLLECTION,
)

@tool
def aggregation_retrieval(agg_pipeline: list) -> list:
    """Given a MongoDB query and list of projections, this function retrieves and returns the
    relevant information in the documents.
    Use a project stage as the first stage to minimize the size of the queries before proceeding with the remaining steps.
    The input to $map must be an array not a string, avoid using it in the $project stage.

    Parameters
    ----------
    agg_pipeline
        MongoDB aggregation pipeline

    Returns
    -------
    list
        List of retrieved documents
    """
    # NOTE(review): via the @tool decorator this docstring presumably doubles
    # as the tool description shown to the LLM agent (it addresses the model,
    # e.g. the $project/$map guidance) — confirm before rewording it.
    # Delegates the aggregation to the shared DocDB API client defined at
    # module level.
    result = docdb_api_client.aggregate_docdb_records(
        pipeline=agg_pipeline
    )
    return result

tools = [aggregation_retrieval]
tool_model = SONNET_3_5_LLM.bind_tools(tools)

db_prompt = hub.pull("eden19/entire_db_retrieval")
langgraph_agent_executor = create_react_agent(SONNET_3_LLM, tools=tools, state_modifier= db_prompt)

db_surveyor_agent = create_tool_calling_agent(SONNET_3_LLM, tools, db_prompt)
query_retriever = AgentExecutor(agent=db_surveyor_agent, tools=tools, return_intermediate_steps = True, verbose=False)


# Generating appropriate filter
class FilterGenerator(TypedDict):
"""MongoDB filter to be applied before vector retrieval"""
Expand All @@ -95,7 +53,7 @@ class FilterGenerator(TypedDict):
class RetrievalGrader(TypedDict):
"""Relevant material in the retrieved document + Binary score to check relevance to the question"""

relevant_context:Annotated[str, ..., "Relevant context extracted from document that helps directly answer the question"]
#relevant_context:Annotated[str, ..., "Relevant context extracted from document that helps directly answer the question"]
binary_score: Annotated[Literal["yes", "no"], ..., "Retrieved documents are relevant to the query, 'yes' or 'no'"]

retrieval_grader = SONNET_3_5_LLM.with_structured_output(RetrievalGrader)
Expand Down
Loading
Loading