Merge pull request #4 from AllenNeuralDynamics/error_handling
efficiency improved + error handling
sreyakumar authored Nov 25, 2024
2 parents 44abed7 + e66f34f commit 82e3cfb
Showing 23 changed files with 248 additions and 3,078 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -1,4 +1,4 @@
-# GAMER: Generative Analysis for Metadata Retrieval
+# GAMER: Generative Analysis for Metadata Retrieval

[![License](https://img.shields.io/badge/license-MIT-brightgreen)](LICENSE)
![Code Style](https://img.shields.io/badge/code%20style-black-black)
@@ -29,6 +29,8 @@ Install the chatbot package -- ensure virtual environment is running.
pip install metadata-chatbot
```

+Create a folder called `huggingface_cache` in the directory where you run the model.

## Usage

To call the model,
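Taken together, the README additions describe this commit's expected setup. The following is a minimal sketch, not part of the repository: the cache-folder name comes from the README above, the `GAMER` and `ainvoke` calls mirror the `app.py` added below, and the example question is invented.

```python
import asyncio
from pathlib import Path

from metadata_chatbot.agents.GAMER import GAMER

# Folder the README asks for, created in the current working directory
Path("huggingface_cache").mkdir(exist_ok=True)


async def demo():
    llm = GAMER()
    # ainvoke is the async entry point that app.py (below) also uses
    print(await llm.ainvoke("What subjects are in the AIND metadata?"))


asyncio.run(demo())
```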
39 changes: 39 additions & 0 deletions app.py
@@ -0,0 +1,39 @@
# Streamlit chat front-end for the GAMER metadata chatbot
import asyncio

import streamlit as st

from metadata_chatbot.agents.GAMER import GAMER

# Run from the terminal with: streamlit run <FILE PATH> [ARGUMENTS]


async def main():
    llm = GAMER()

    # Greet the user on the app's webpage
    message = st.chat_message("assistant")
    message.write("Hello!")

    prompt = st.chat_input("Ask a question about the AIND Metadata!")

    # Initialize the chat history on the first run
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Streamlit reruns the whole script on every interaction, so replay the
    # stored history to keep the conversation visible
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt:
        # Display the user message in a chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Add the user message to the chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        response = await llm.ainvoke(prompt)

        with st.chat_message("assistant"):
            st.markdown(response)

        # Add the assistant response to the chat history
        st.session_state.messages.append(
            {"role": "assistant", "content": response}
        )


if __name__ == "__main__":
    asyncio.run(main())
54 changes: 6 additions & 48 deletions src/metadata_chatbot/agents/agentic_graph.py
@@ -9,19 +9,21 @@
from langgraph.prebuilt import create_react_agent

MODEL_ID_SONNET_3 = "anthropic.claude-3-sonnet-20240229-v1:0"
-MODEL_ID_SONNET_3_5 = "anthropic.claude-3-sonnet-20240229-v1:0"
+MODEL_ID_SONNET_3_5 = "anthropic.claude-3-5-sonnet-20240620-v1:0"

SONNET_3_LLM = ChatBedrock(
    model_id=MODEL_ID_SONNET_3,
    model_kwargs={
        "temperature": 0
-    }
+    },
+    streaming=True
)

SONNET_3_5_LLM = ChatBedrock(
    model_id=MODEL_ID_SONNET_3_5,
    model_kwargs={
        "temperature": 0
-    }
+    },
+    streaming=True
)
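The `streaming=True` flag added to both models enables token-level streaming. A hedged sketch of consuming it through LangChain's standard `astream` interface follows; this code is not part of the diff, and the prompt text is invented.

```python
import asyncio

from metadata_chatbot.agents.agentic_graph import SONNET_3_5_LLM


async def stream_demo():
    # With streaming=True, the chat model yields message chunks incrementally
    # rather than a single final response.
    async for chunk in SONNET_3_5_LLM.astream("Summarize the AIND metadata schema."):
        print(chunk.content, end="", flush=True)


asyncio.run(stream_demo())
```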

# Determining if entire database needs to be surveyed
@@ -35,50 +37,6 @@ class RouteQuery(TypedDict):
router_prompt = hub.pull("eden19/query_rerouter")
datasource_router = router_prompt | structured_llm_router

-# Tool to survey entire database
-API_GATEWAY_HOST = "api.allenneuraldynamics.org"
-DATABASE = "metadata_index"
-COLLECTION = "data_assets"
-
-docdb_api_client = MetadataDbClient(
-    host=API_GATEWAY_HOST,
-    database=DATABASE,
-    collection=COLLECTION,
-)
-
-@tool
-def aggregation_retrieval(agg_pipeline: list) -> list:
-    """Given a MongoDB aggregation pipeline and a list of projections, retrieve
-    and return the relevant information from the matching documents.
-
-    Use a $project stage as the first stage to minimize the size of the queries
-    before proceeding with the remaining steps. The input to $map must be an
-    array, not a string; avoid using it in the $project stage.
-
-    Parameters
-    ----------
-    agg_pipeline
-        MongoDB aggregation pipeline
-
-    Returns
-    -------
-    list
-        List of retrieved documents
-    """
-    result = docdb_api_client.aggregate_docdb_records(
-        pipeline=agg_pipeline
-    )
-    return result
-
-tools = [aggregation_retrieval]
-tool_model = SONNET_3_5_LLM.bind_tools(tools)
-
-db_prompt = hub.pull("eden19/entire_db_retrieval")
-langgraph_agent_executor = create_react_agent(
-    SONNET_3_LLM, tools=tools, state_modifier=db_prompt
-)
-
-db_surveyor_agent = create_tool_calling_agent(SONNET_3_LLM, tools, db_prompt)
-query_retriever = AgentExecutor(
-    agent=db_surveyor_agent, tools=tools, return_intermediate_steps=True, verbose=False
-)


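For reference, the removed `aggregation_retrieval` tool took a plain MongoDB aggregation pipeline, with a `$project` first stage as its docstring requests. A hypothetical input, with invented field names:

```python
# Hypothetical pipeline of the shape the removed tool expected: a $project
# first stage to shrink documents, then a filter and a result cap.
example_pipeline = [
    {"$project": {"_id": 0, "subject.subject_id": 1, "session.session_type": 1}},
    {"$match": {"session.session_type": "ecephys"}},
    {"$limit": 5},
]
# Before this commit it could be run as a LangChain tool:
# aggregation_retrieval.invoke({"agg_pipeline": example_pipeline})
```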
# Generating appropriate filter
class FilterGenerator(TypedDict):
    """MongoDB filter to be applied before vector retrieval"""
@@ -95,7 +53,7 @@ class FilterGenerator(TypedDict):
class RetrievalGrader(TypedDict):
    """Relevant material in the retrieved document + binary score to check relevance to the question"""

-    relevant_context: Annotated[str, ..., "Relevant context extracted from document that helps directly answer the question"]
+    # relevant_context: Annotated[str, ..., "Relevant context extracted from document that helps directly answer the question"]
    binary_score: Annotated[Literal["yes", "no"], ..., "Retrieved documents are relevant to the query, 'yes' or 'no'"]

retrieval_grader = SONNET_3_5_LLM.with_structured_output(RetrievalGrader)
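Because `retrieval_grader` wraps the model with `with_structured_output(RetrievalGrader)`, it returns a dict matching that schema. A hedged invocation sketch, with invented question and document text:

```python
# Hypothetical invocation; the question and document text are invented.
grade = retrieval_grader.invoke(
    "Question: Which sessions used ecephys?\n"
    "Document: Subject 12345 underwent an ecephys recording session."
)
print(grade["binary_score"])  # "yes" or "no"
```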