Skip to content

Commit

Permalink
Removed the reasoning field from structured router/grader outputs and added an async evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
sreyakumar committed Oct 30, 2024
1 parent 5270b1f commit a7b0678
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 47 deletions.
78 changes: 53 additions & 25 deletions GAMER_workbook.ipynb

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions evaluations/async_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from utils import LLM
from langchain_core.prompts.prompt import PromptTemplate
from langsmith.evaluation import LangChainStringEvaluator
from metadata_chatbot.agents.GAMER import GAMER
from langsmith import aevaluate
from evaluation_dataset import dataset_name


# Grading prompt for the LLM-as-judge evaluator: the judge sees the original
# question, the reference answer from the dataset, and the chatbot's predicted
# answer, and must respond with CORRECT or INCORRECT.
_PROMPT_TEMPLATE = """You are an expert professor specialized in grading students' answers to questions.
You are grading the following question:
{query}
Here is the real answer:
{answer}
You are grading the following predicted answer:
{result}
Respond with CORRECT or INCORRECT:
Grade:
"""

# The variable names "query", "answer", and "result" are the ones the
# LangChain "qa" string evaluator fills in at grading time — do not rename.
PROMPT = PromptTemplate(
    input_variables=["query", "answer", "result"], template=_PROMPT_TEMPLATE
)

# Correctness grader backed by the shared LLM; used by main() below via
# langsmith.aevaluate.
evaluator = LangChainStringEvaluator("qa", config={"llm": LLM, "prompt": PROMPT})

async def my_app(question):
    """Answer a single question with a fresh GAMER chatbot instance.

    A new model is constructed per call; the awaited ``ainvoke`` result is
    returned unchanged.
    """
    gamer = GAMER()
    answer = await gamer.ainvoke(question)
    return answer

async def langsmith_app(inputs):
    """Adapt a LangSmith example dict to the chatbot call.

    Pulls the ``"question"`` field out of *inputs*, awaits the model, and
    wraps the reply in the ``{"output": ...}`` shape LangSmith expects.
    """
    answer = await my_app(inputs["question"])
    return {"output": answer}

async def main(experiment_prefix="metadata-chatbot-0.0.44"):
    """Run the async LangSmith evaluation experiment and return its results.

    Parameters
    ----------
    experiment_prefix : str, optional
        Label prefix for the experiment in the LangSmith UI. Parameterized
        (with the previous hard-coded value as default) so the version string
        is bumped in one place instead of being duplicated across evaluation
        scripts; existing callers are unaffected.

    Returns
    -------
    The experiment results object produced by ``langsmith.aevaluate``.
    """
    experiment_results = await aevaluate(
        langsmith_app,                       # the AI system under evaluation
        data=dataset_name,                   # dataset to predict and grade over
        evaluators=[evaluator],              # LLM-as-judge correctness grader
        experiment_prefix=experiment_prefix, # identifies this run in LangSmith
    )
    return experiment_results

if __name__ == "__main__":
    # Local import keeps asyncio out of the module namespace when this file
    # is imported by other evaluation scripts rather than run directly.
    import asyncio
    asyncio.run(main())
6 changes: 5 additions & 1 deletion evaluations/evaluation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
{"question": "What are the injections for SmartSPIM_675387_2023-05-23_23-05-56?"},
{"question": "What are all the assets using mouse 675387"},
{"question": "Write a MongoDB query to find the genotype of SmartSPIM_675387_2023-05-23_23-05-56"},
{"question": "How many records are stored in the database?"},
{"question": "What are the unique modalities found in the database?"},
],
outputs=[
{"answer": "The genotype for subject 675387 is wt/wt"},
Expand All @@ -36,7 +38,9 @@
2. The `$project` stage excludes the `_id` field and includes the `genotype` field from the nested `subject` object.
The retrieved output shows that the genotype for this experiment is "wt/wt".
"""},
{"answer": "There are 267 records found in the database."},
{"answer": "The unique modalities in the database are Behavior, Behavior videos, Planar optical physiology and Selective plane illumination microscopy."},
],
dataset_id=dataset.id,
)
'''
'''
2 changes: 1 addition & 1 deletion evaluations/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def langsmith_app(inputs):
langsmith_app, # Your AI system
data=dataset_name, # The data to predict and grade over
evaluators=[evaluator], # The evaluators to score the results
experiment_prefix="metadata-chatbot-0.0.41", # A prefix for your experiment names to easily identify them
experiment_prefix="metadata-chatbot-0.0.44", # A prefix for your experiment names to easily identify them
)
22 changes: 9 additions & 13 deletions src/metadata_chatbot/agents/agentic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from langchain.agents import AgentExecutor, create_tool_calling_agent
from aind_data_access_api.document_db import MetadataDbClient
from pprint import pprint
from typing_extensions import Annotated, TypedDict


logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

Expand All @@ -23,9 +25,9 @@
class RouteQuery(BaseModel):
"""Route a user query to the most relevant datasource."""

reasoning: str = Field(
description="Give a justification for the chosen method",
)
# reasoning: str = Field(
# description="Give a justification for the chosen method",
# )

datasource: Literal["vectorstore", "direct_database"] = Field(
description="Given a user question choose to route it to the direct database or its vectorstore.",
Expand Down Expand Up @@ -107,20 +109,14 @@ class FilterGenerator(BaseModel):
#print(filter_generation_chain.invoke({"query": "What is the genotype for mouse 675387?"}).filter_query)

# Check if retrieved documents answer question
class RetrievalGrader(BaseModel):
class RetrievalGrader(TypedDict):
"""Binary score to check whether retrieved documents are relevant to the question"""

reasoning: str = Field(
description="Give a reasoning as to what makes the document relevant for the chosen method",
)
#reasoning: Annotated[str, ..., "Give a reasoning as to what makes the document relevant for the chosen method"]

binary_score: str = Field(
description="Retrieved documents are relevant to the query, 'yes' or 'no'"
)
binary_score: Annotated[Literal["yes", "no"], ..., "Retrieved documents are relevant to the query, 'yes' or 'no'"]

relevant_context: str = Field(
description="Relevant pieces of context in document"
)
relevant_context: Annotated[str, None, "Relevant pieces of context in document"]

retrieval_grader = LLM.with_structured_output(RetrievalGrader)
retrieval_grade_prompt = hub.pull("eden19/retrievalgrader")
Expand Down
16 changes: 9 additions & 7 deletions src/metadata_chatbot/agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from metadata_chatbot.agents.docdb_retriever import DocDBRetriever

#from agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")
Expand Down Expand Up @@ -42,7 +44,7 @@ def route_question(state):
logging.info("Querying against vector embeddings...")
return "vectorstore"

def generate_for_whole_db(state):
def retrieve_DB(state):
"""
Filter database
Expand Down Expand Up @@ -90,7 +92,7 @@ def filter_generator(state):
else:
return {"filter": None, "query": query}

def retrieve(state):
def retrieve_VI(state):
"""
Retrieve documents
Expand Down Expand Up @@ -134,11 +136,11 @@ def grade_documents(state):
filtered_docs = []
for doc in documents:
score = doc_grader.invoke({"query": query, "document": doc.page_content})
grade = score.binary_score
grade = score['binary_score']
logging.info(f"Retrieved document matched query: {grade}")
if grade == "yes":
logging.info("Document is relevant to the query")
relevant_context = score.relevant_context
relevant_context = score['relevant_context']
filtered_docs.append(relevant_context)
else:
logging.info("Document is not relevant and will be removed")
Expand Down Expand Up @@ -183,9 +185,9 @@ def generate_vi(state):
return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}

workflow = StateGraph(GraphState)
workflow.add_node("database_query", generate_for_whole_db)
workflow.add_node("database_query", retrieve_DB)
workflow.add_node("filter_generation", filter_generator)
workflow.add_node("retrieve", retrieve)
workflow.add_node("retrieve", retrieve_VI)
workflow.add_node("document_grading", grade_documents)
workflow.add_node("generate_db", generate_db)
workflow.add_node("generate_vi", generate_vi)
Expand All @@ -208,7 +210,7 @@ def generate_vi(state):

app = workflow.compile()

# query = "Write a MongoDB query to find the genotype of SmartSPIM_675387_2023-05-23_23-05-56"
# query = "What are the injections used in asset SmartSPIM_692908_2023-11-08_16-48-13_stitched_2023-11-09_11-12-06"

# inputs = {"query" : query}
# answer = app.invoke(inputs)
Expand Down

0 comments on commit a7b0678

Please sign in to comment.