Commit 5325051: upgraded sonnet 3.5
sreyakumar committed Nov 1, 2024
1 parent 0802941 commit 5325051
Showing 8 changed files with 64 additions and 67 deletions.
26 changes: 8 additions & 18 deletions GAMER_workbook.ipynb
@@ -155,7 +155,7 @@
"output_type": "stream",
"text": [
"Name: metadata-chatbot\n",
"Version: 0.0.45\n",
"Version: 0.0.49\n",
"Summary: Generated from aind-library-template\n",
"Home-page: \n",
"Author: Allen Institute for Neural Dynamics\n",
@@ -176,19 +176,15 @@
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:111: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
" documents = retriever.get_relevant_documents(query = query, query_filter = filter)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, I am unable to determine all the assets using mouse 675387. The context provides details about the procedures performed on mouse 675387, such as virus injections, fixation, delipidation, and refractive index matching, as well as some acquisition details like imaging channels and coordinate transformations. However, it does not explicitly list or mention all the assets associated with this mouse subject.\n"
"The retrieved output shows two assets using mouse 675387:\n",
"\n",
"1. An asset with id \"14c42287-1157-4985-b0f2-633ef9b289e3\", named \"SmartSPIM_675387_2023-05-23_23-05-56\", with a modality of \"Selective plane illumination microscopy\".\n",
"\n",
"2. An asset with id \"ac3f5ca8-6c9c-42d5-b3d0-952565ac4d59\", named \"SmartSPIM_675387_2023-05-23_23-05-56_stitched_2023-06-01_22-30-44\", also with a modality of \"Selective plane illumination microscopy\".\n"
]
}
],
@@ -233,20 +229,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, the channels imaged in asset SmartSPIM_692908_2023-11-08_16-48-13_stitched_2023-11-09_11-12-06 were:\n",
"\n",
"Channel 488 nm: This channel likely images a green fluorescent protein label, potentially EGFP expressed under the TICL promoter.\n",
"\n",
"Channel 561 nm: This channel likely images a red fluorescent protein label like tdTomato, potentially expressed under the ICF promoter. \n",
"\n",
"Channel 639 nm: This channel likely images a far-red fluorescent protein label.\n"
"Based on the provided context, I could not find any information about the age of the subject at the time of imaging for the specific session SmartSPIM_662616_2023-03-06_17-47-13. The context contains conflicting information about the subject's date of birth and does not explicitly state the age for that particular session. I am unable to provide a definitive answer to the question.\n"
]
}
],
"source": [
"from metadata_chatbot.agents.GAMER import GAMER\n",
"llm = GAMER()\n",
"query = \"Which channels were imaged in asset SmartSPIM_692908_2023-11-08_16-48-13_stitched_2023-11-09_11-12-06? What is labelled in each channel?\"\n",
"query = \"What was the age of the subject at the time of imaging in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
"\n",
"result = await llm.ainvoke(query)\n",
"print(result)"
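As a supplement to the notebook cell above, a minimal sketch of driving the same call from a plain Python script (this assumes the `metadata_chatbot` package is installed and AWS credentials are configured; `asyncio.run` supplies the event loop that the notebook provides implicitly):

```python
import asyncio

from metadata_chatbot.agents.GAMER import GAMER


async def main() -> None:
    # GAMER wraps the compiled agent workflow behind a single async call.
    llm = GAMER()
    query = (
        "What was the age of the subject at the time of imaging "
        "in SmartSPIM_662616_2023-03-06_17-47-13"
    )
    result = await llm.ainvoke(query)
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
```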
7 changes: 2 additions & 5 deletions embeddings/langchain_vector_store.py
@@ -11,17 +11,14 @@
from bson import json_util
from langchain_text_splitters import RecursiveJsonSplitter

sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.utils import create_ssh_tunnel, CONNECTION_STRING, BEDROCK_CLIENT, ResourceManager
#sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from utils import create_ssh_tunnel, CONNECTION_STRING, BEDROCK_EMBEDDINGS, ResourceManager

logging.basicConfig(filename='vector_store.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

TOKEN_LIMIT = 8192
JSON_SPLITTER = RecursiveJsonSplitter(max_chunk_size=TOKEN_LIMIT)

BEDROCK_EMBEDDINGS = BedrockEmbeddings(model_id = "amazon.titan-embed-text-v2:0",client = BEDROCK_CLIENT)


def regex_modality_PHYSIO(record_name: str) -> bool:

PHYSIO_modalities = ["behavior", "Other", "FIP", "phys", "HSFP"]
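For context on the `TOKEN_LIMIT` and `JSON_SPLITTER` constants above, a hedged sketch of how `RecursiveJsonSplitter` chunks an oversized metadata record; the sample record here is hypothetical, standing in for a real DocDB document:

```python
from langchain_text_splitters import RecursiveJsonSplitter

TOKEN_LIMIT = 8192
JSON_SPLITTER = RecursiveJsonSplitter(max_chunk_size=TOKEN_LIMIT)

# Hypothetical record standing in for a real metadata document.
record = {
    "subject": {"subject_id": "675387", "genotype": "wt/wt"},
    "procedures": {"injections": ["virus injection"], "specimen": "delipidation"},
}

# split_json returns a list of smaller dicts, each within max_chunk_size.
chunks = JSON_SPLITTER.split_json(json_data=record)
print(len(chunks))
```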
File renamed without changes.
2 changes: 1 addition & 1 deletion evaluations/async_evaluator.py
@@ -36,7 +36,7 @@ async def main():
langsmith_app, # Your AI system
data=dataset_name, # The data to predict and grade over
evaluators=[evaluator], # The evaluators to score the results
experiment_prefix="metadata-chatbot-0.0.44", # A prefix for your experiment names to easily identify them
experiment_prefix="async-metadata-chatbot-0.0.49", # A prefix for your experiment names to easily identify them
)
return experiment_results

2 changes: 1 addition & 1 deletion evaluations/evaluator.py
@@ -35,5 +35,5 @@ def langsmith_app(inputs):
langsmith_app, # Your AI system
data=dataset_name, # The data to predict and grade over
evaluators=[evaluator], # The evaluators to score the results
experiment_prefix="metadata-chatbot-0.0.48", # A prefix for your experiment names to easily identify them
experiment_prefix="metadata-chatbot-0.0.49", # A prefix for your experiment names to easily identify them
)
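For orientation, a hedged sketch of the LangSmith `evaluate` call that these `experiment_prefix` values feed into; the target function body and the dataset name here are illustrative stand-ins rather than the repository's actual ones, and LangSmith credentials are assumed:

```python
from langsmith import evaluate


def langsmith_app(inputs: dict) -> dict:
    # Stand-in target: run the chatbot on one dataset example.
    return {"output": "placeholder answer for " + inputs["question"]}


experiment_results = evaluate(
    langsmith_app,                                # Your AI system
    data="metadata-chatbot-eval",                 # Hypothetical dataset name
    evaluators=[],                                # Evaluators to score the results
    experiment_prefix="metadata-chatbot-0.0.49",  # Prefix to identify the runs
)
```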
52 changes: 28 additions & 24 deletions src/metadata_chatbot/agents/agentic_graph.py
@@ -10,12 +10,19 @@
from pprint import pprint
from typing_extensions import Annotated, TypedDict


logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
LLM = ChatBedrock(
model_id= MODEL_ID,
MODEL_ID_SONNET_3 = "anthropic.claude-3-sonnet-20240229-v1:0"
MODEL_ID_SONNET_3_5 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
SONNET_3_LLM = ChatBedrock(
model_id= MODEL_ID_SONNET_3,
model_kwargs= {
"temperature": 0
}
)

SONNET_3_5_LLM = ChatBedrock(
model_id= MODEL_ID_SONNET_3_5,
model_kwargs= {
"temperature": 0
}
@@ -33,13 +40,11 @@ class RouteQuery(BaseModel):
description="Given a user question, choose to route it to the direct database or its vectorstore.",
)

structured_llm_router = LLM.with_structured_output(RouteQuery)
structured_llm_router = SONNET_3_LLM.with_structured_output(RouteQuery)
router_prompt = hub.pull("eden19/query_rerouter")
datasource_router = router_prompt | structured_llm_router
#print(datasource_router.invoke({"query": "What is the mongodb query to find the injections for SmartSPIM_675387_2023-05-23_23-05-56?"}).datasource)



# Queries that require surveying the entire database (like count based questions)
# credentials = DocumentDbSSHCredentials()
# credentials.database = "metadata_vector_index"
@@ -79,20 +84,21 @@ def aggregation_retrieval(agg_pipeline: list) -> list:

tools = [aggregation_retrieval]
db_prompt = hub.pull("eden19/entire_db_retrieval")
db_surveyor_agent = create_tool_calling_agent(LLM, tools, db_prompt)
db_surveyor_agent = create_tool_calling_agent(SONNET_3_LLM, tools, db_prompt)
query_retriever = AgentExecutor(agent=db_surveyor_agent, tools=tools, return_intermediate_steps = True, verbose=False)

# Processing query
class ProcessQuery(BaseModel):
"""Binary score to check whether query requires retrieval to be filtered with metadata information to achieve accurate results."""
# class ProcessQuery(BaseModel):
# """Binary score to check whether query requires retrieval to be filtered with metadata information to achieve accurate results."""

binary_score: str = Field(
description="Query requires further filtering during retrieval process, 'yes' or 'no'"
)
# binary_score: str = Field(
# description="Query requires further filtering during retrieval process, 'yes' or 'no'"
# )
# reasoning: str = Field("One short sentence justifying why a filter was picked or not picked")

query_grader = LLM.with_structured_output(ProcessQuery)
query_grade_prompt = hub.pull("eden19/processquery")
query_grader = query_grade_prompt | query_grader
# query_grader = SONNET_3_5_LLM.with_structured_output(ProcessQuery)
# query_grade_prompt = hub.pull("eden19/processquery")
# query_grader = query_grade_prompt | query_grader
#print(query_grader.invoke({"query": "What is the genotype for mouse 675387?"}).binary_score)

# Generating appropriate filter
@@ -103,7 +109,7 @@ class FilterGenerator(BaseModel):
#top_k: int = Field(description="Number of documents to retrieve from the database")

filter_prompt = hub.pull("eden19/filtergeneration")
filter_generator_llm = LLM.with_structured_output(FilterGenerator)
filter_generator_llm = SONNET_3_LLM.with_structured_output(FilterGenerator)

filter_generation_chain = filter_prompt | filter_generator_llm
#print(filter_generation_chain.invoke({"query": "What is the genotype for mouse 675387?"}).filter_query)
@@ -112,21 +118,19 @@ class FilterGenerator(BaseModel):
class RetrievalGrader(TypedDict):
"""Binary score to check whether retrieved documents are relevant to the question"""

#reasoning: Annotated[str, ..., "Give a reasoning as to what makes the document relevant for the chosen method"]

relevant_context:Annotated[str, ..., "Relevant context extracted from document that helps directly answer the question"]
binary_score: Annotated[Literal["yes", "no"], ..., "Retrieved documents are relevant to the query, 'yes' or 'no'"]
#relevant_context: Annotated[str, None, "Summarize relevant pieces of context in document"]

relevant_context: Annotated[str, None, "Relevant pieces of context in document"]

retrieval_grader = LLM.with_structured_output(RetrievalGrader)
retrieval_grader = SONNET_3_5_LLM.with_structured_output(RetrievalGrader)
retrieval_grade_prompt = hub.pull("eden19/retrievalgrader")
doc_grader = retrieval_grade_prompt | retrieval_grader
# doc_grade = doc_grader.invoke({"query": question, "document": doc}).binary_score
# logging.info(f"Retrieved document matched query: {doc_grade}")

# Generating response to documents
answer_generation_prompt = hub.pull("eden19/answergeneration")
rag_chain = answer_generation_prompt | LLM | StrOutputParser()
rag_chain = answer_generation_prompt | SONNET_3_LLM | StrOutputParser()

db_answer_generation_prompt = hub.pull("eden19/db_answergeneration")
# class DatabaseGeneration(BaseModel):
@@ -141,6 +145,6 @@ class RetrievalGrader(TypedDict):
# )

# database_answer_generation = LLM.with_structured_output(DatabaseGeneration)
db_rag_chain = db_answer_generation_prompt | LLM | StrOutputParser()
db_rag_chain = db_answer_generation_prompt | SONNET_3_5_LLM | StrOutputParser()
# generation = rag_chain.invoke({"documents": doc, "query": question})
# logging.info(f"Final answer: {generation}")
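To make the `with_structured_output` pattern above concrete, a minimal sketch of grading one document against a query; the inline prompt is a stand-in for the `eden19/retrievalgrader` prompt pulled from the hub, and AWS Bedrock credentials are assumed:

```python
from typing import Literal

from langchain_aws import ChatBedrock
from langchain_core.prompts import ChatPromptTemplate
from typing_extensions import Annotated, TypedDict


class RetrievalGrader(TypedDict):
    """Binary score to check whether retrieved documents are relevant to the question"""

    relevant_context: Annotated[str, None, "Relevant pieces of context in document"]
    binary_score: Annotated[Literal["yes", "no"], ..., "Retrieved documents are relevant to the query, 'yes' or 'no'"]


SONNET_3_5_LLM = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    model_kwargs={"temperature": 0},
)

# Stand-in for hub.pull("eden19/retrievalgrader").
retrieval_grade_prompt = ChatPromptTemplate.from_template(
    "Does this document help answer the query?\nQuery: {query}\nDocument: {document}"
)
doc_grader = retrieval_grade_prompt | SONNET_3_5_LLM.with_structured_output(RetrievalGrader)

grade = doc_grader.invoke(
    {"query": "What is the genotype for mouse 675387?", "document": "..."}
)
print(grade["binary_score"])
```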
11 changes: 8 additions & 3 deletions src/metadata_chatbot/agents/docdb_retriever.py
@@ -1,4 +1,4 @@
import sys, os, json
import sys, os, json, boto3
from typing import List, Optional, Any, Union, Annotated
from pymongo.collection import Collection
from motor.motor_asyncio import AsyncIOMotorCollection
@@ -9,9 +9,14 @@
from langsmith import trace as langsmith_trace
from pydantic import Field
from aind_data_access_api.document_db import MetadataDbClient
from langchain_aws import BedrockEmbeddings

sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.utils import BEDROCK_EMBEDDINGS
BEDROCK_CLIENT = boto3.client(
service_name="bedrock-runtime",
region_name = 'us-west-2'
)

BEDROCK_EMBEDDINGS = BedrockEmbeddings(model_id="amazon.titan-embed-text-v2:0",client=BEDROCK_CLIENT)

API_GATEWAY_HOST = "api.allenneuraldynamics-test.org"
DATABASE = "metadata_vector_index"
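A brief usage sketch for the `BEDROCK_EMBEDDINGS` client defined above, assuming AWS credentials with Bedrock access in `us-west-2`:

```python
import boto3
from langchain_aws import BedrockEmbeddings

BEDROCK_CLIENT = boto3.client(
    service_name="bedrock-runtime",
    region_name="us-west-2",
)
BEDROCK_EMBEDDINGS = BedrockEmbeddings(
    model_id="amazon.titan-embed-text-v2:0",
    client=BEDROCK_CLIENT,
)

# embed_query returns a list[float]; Titan Text Embeddings V2 defaults to 1024 dimensions.
vector = BEDROCK_EMBEDDINGS.embed_query("What are all the assets using mouse 675387")
print(len(vector))
```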
31 changes: 16 additions & 15 deletions src/metadata_chatbot/agents/workflow.py
@@ -4,8 +4,8 @@
from langgraph.graph import END, StateGraph, START
from metadata_chatbot.agents.docdb_retriever import DocDBRetriever

#from agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
from agentic_graph import datasource_router, query_retriever, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
#from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

@@ -81,16 +81,16 @@ def filter_generator(state):

query = state["query"]

query_grade = query_grader.invoke({"query": query}).binary_score
logging.info(f"Database needs to be further filtered: {query_grade}")
# query_grade = query_grader.invoke({"query": query}).binary_score
# logging.info(f"Database needs to be further filtered: {query_grade}")

if query_grade == "yes":
filter = filter_generation_chain.invoke({"query": query}).filter_query
#top_k = filter_generation_chain.invoke({"query": query}).top_k
logging.info(f"Database will be filtered using: {filter}")
return {"filter": filter, "query": query}
else:
return {"filter": None, "query": query}
#if query_grade == "yes":
filter = filter_generation_chain.invoke({"query": query}).filter_query
#top_k = filter_generation_chain.invoke({"query": query}).top_k
logging.info(f"Database will be filtered using: {filter}")
return {"filter": filter, "query": query}
# else:
# return {"filter": None, "query": query}

def retrieve_VI(state):
"""
@@ -146,6 +146,7 @@ def grade_documents(state):
logging.info("Document is not relevant and will be removed")
continue
#doc_text = "\n\n".join(doc.page_content for doc in filtered_docs)
#print(filtered_docs)
return {"documents": filtered_docs, "query": query}

def generate_db(state):
@@ -210,8 +211,8 @@ def generate_vi(state):

app = workflow.compile()

# query = "What are all the assets using mouse 675387"
query = "How was the tissue prepared for imaging, including fixation, delipidation, and refractive index matching procedures? in experiment: SmartSPIM_675388_2023-05-24_04-10-19_stitched_2023-05-28_18-07-46"

# inputs = {"query" : query}
# answer = app.invoke(inputs)
# print(answer['generation'])
inputs = {"query" : query}
answer = app.invoke(inputs)
print(answer['generation'])
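For orientation, a hedged sketch of the `StateGraph` wiring that `workflow.compile()` above implies; the state schema and the single placeholder node are simplified assumptions inferred from the node functions visible in the diff, not the repository's full graph:

```python
from typing import Any, List, Optional

from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict


class GraphState(TypedDict, total=False):
    # Assumed state keys, inferred from the node functions in the diff.
    query: str
    filter: Optional[dict]
    documents: List[Any]
    generation: str


def filter_generator(state: GraphState) -> dict:
    # Placeholder node: the real implementation calls filter_generation_chain.
    return {"filter": {"subject.subject_id": "662616"}, "query": state["query"]}


workflow = StateGraph(GraphState)
workflow.add_node("filter_generator", filter_generator)
workflow.add_edge(START, "filter_generator")
workflow.add_edge("filter_generator", END)

app = workflow.compile()
print(app.invoke({"query": "What are all the assets using mouse 675387"}))
```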
