diff --git a/examples/research_agent/question_generator.py b/examples/research_agent/question_generator.py index 9a2223d1..566e1451 100644 --- a/examples/research_agent/question_generator.py +++ b/examples/research_agent/question_generator.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Dict, Any +from typing import Optional, Any import json from pathlib import Path @@ -6,7 +6,6 @@ from langchain_core.runnables import ( RunnablePassthrough, RunnableLambda, - RunnableParallel, ) from langchain_core.tools import Tool from langchain_core.prompts.base import BasePromptTemplate @@ -14,19 +13,20 @@ from langchain_core.pydantic_v1 import BaseModel, Field -# TODO: fallback interface if LlamaIndex is not available -from llama_index.core.graph_stores.types import GraphStore - from motleycrew.tool import MotleyTool from motleycrew.common import LLMFramework from motleycrew.common.llms import init_llm -from motleycrew.tool.question_insertion_tool import QuestionInsertionTool from motleycrew.common.utils import print_passthrough +from motleycrew.storage import MotleyGraphStore + +from question_struct import Question +from question_inserter import QuestionInsertionTool + default_prompt = PromptTemplate.from_template( """ You are a part of a team. The ultimate goal of your team is to -answer the following Question: '{question}'.\n +answer the following Question: '{question_text}'.\n Your team has discovered some new text (delimited by ```) that may be relevant to your ultimate goal. text: \n ``` {context} ``` \n Your task is to ask new questions that may help your team achieve the ultimate goal. @@ -57,7 +57,7 @@ class QuestionGeneratorTool(MotleyTool): def __init__( self, query_tool: MotleyTool, - graph: GraphStore, + graph: MotleyGraphStore, max_questions: int = 3, llm: Optional[BaseLanguageModel] = None, prompt: str | BasePromptTemplate = None, @@ -76,14 +76,12 @@ def __init__( class QuestionGeneratorToolInput(BaseModel): """Input for the Question Generator Tool.""" - question: str = Field( - description="The input question for which to generate subquestions." - ) + question: Question = Field(description="The input question for which to generate subquestions.") def create_question_generator_langchain_tool( query_tool: MotleyTool, - graph: GraphStore, + graph: MotleyGraphStore, max_questions: int = 3, llm: Optional[BaseLanguageModel] = None, prompt: str | BasePromptTemplate = None, @@ -98,14 +96,10 @@ def create_question_generator_langchain_tool( elif isinstance(prompt, str): prompt = PromptTemplate.from_template(prompt) - assert isinstance( - prompt, BasePromptTemplate - ), "Prompt must be a string or a BasePromptTemplate" + assert isinstance(prompt, BasePromptTemplate), "Prompt must be a string or a BasePromptTemplate" - def partial_inserter(question: dict[str, str]): - out = QuestionInsertionTool( - graph=graph, question=question["question"] - ).to_langchain_tool() + def partial_inserter(question: Question): + out = QuestionInsertionTool(graph=graph, question=question).to_langchain_tool() return (out,) def insert_questions(input_dict) -> None: @@ -124,7 +118,10 @@ def insert_questions(input_dict) -> None: } | RunnableLambda(print_passthrough) | { - "subquestions": prompt.partial(num_questions=max_questions) | llm, + "subquestions": RunnablePassthrough.assign(question_text=lambda x: x["question"]["question"].question) + | RunnableLambda(print_passthrough) + | prompt.partial(num_questions=max_questions) + | llm, "question_inserter": RunnablePassthrough(), } | RunnableLambda(insert_questions) diff --git a/examples/research_agent/question_inserter.py b/examples/research_agent/question_inserter.py new file mode 100644 index 00000000..4913452f --- /dev/null +++ b/examples/research_agent/question_inserter.py @@ -0,0 +1,75 @@ +from typing import List + +from pathlib import Path + +from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import Tool + +from motleycrew.storage import MotleyGraphStore +from motleycrew.tool import MotleyTool + +from question_struct import Question + + +IS_SUBQUESTION_PREDICATE = "is_subquestion" + + +class QuestionInsertionTool(MotleyTool): + def __init__(self, question: Question, graph: MotleyGraphStore): + + langchain_tool = create_question_insertion_langchain_tool( + name="Question Insertion Tool", + description="Insert a list of questions (supplied as a list of strings) into the graph.", + question=question, + graph=graph, + ) + + super().__init__(langchain_tool) + + +class QuestionInsertionToolInput(BaseModel): + """Subquestions of the current question, to be inserted into the knowledge graph.""" + + questions: List[str] = Field(description="List of questions to be inserted into the knowledge graph.") + + +def create_question_insertion_langchain_tool( + name: str, + description: str, + question: Question, + graph: MotleyGraphStore, +): + def insert_questions(questions: list[str]) -> None: + for subquestion in questions: + subquestion_data = graph.create_entity(Question(question=subquestion).serialize()) + subquestion_obj = Question.deserialize(subquestion_data) + graph.create_rel(from_id=question.id, to_id=subquestion_obj.id, predicate=IS_SUBQUESTION_PREDICATE) + + return Tool.from_function( + func=insert_questions, + name=name, + description=description, + args_schema=QuestionInsertionToolInput, + ) + + +if __name__ == "__main__": + import kuzu + from motleycrew.storage import MotleyKuzuGraphStore + + here = Path(__file__).parent + db_path = here / "test1" + db = kuzu.Database(db_path) + graph_store = MotleyKuzuGraphStore( + db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"} + ) + + question_data = graph_store.create_entity(Question(question="What is the capital of France?").serialize()) + question = Question.deserialize(question_data) + + children = ["What is the capital of France?", "What is the capital of Germany?"] + tool = QuestionInsertionTool(question=question, graph=graph_store) + tool.invoke({"questions": children}) + + print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest") + print("MATCH (A)-[r]->(B) RETURN *;") diff --git a/examples/research_agent/question_struct.py b/examples/research_agent/question_struct.py new file mode 100644 index 00000000..a460586d --- /dev/null +++ b/examples/research_agent/question_struct.py @@ -0,0 +1,35 @@ +from typing import Optional +from dataclasses import dataclass +import json + + +@dataclass +class Question: + id: Optional[int] = None + question: Optional[str] = None + answer: Optional[str] = None + context: Optional[list[str]] = None + + def serialize(self): + data = {} + + if self.id: + data["id"] = json.dumps(self.id) + if self.context: + data["question"] = json.dumps(self.question) + if self.context: + data["answer"] = json.dumps(self.answer) + if self.context: + data["context"] = json.dumps(self.context) + + return data + + @staticmethod + def deserialize(data: dict): + context_raw = data["context"] + if context_raw: + context = json.loads(context_raw) + else: + context = None + + return Question(id=data["id"], question=data["question"], answer=data["answer"], context=context) diff --git a/examples/research_agent/research_agent.py b/examples/research_agent/research_agent.py index 491d9846..fc9c6285 100644 --- a/examples/research_agent/research_agent.py +++ b/examples/research_agent/research_agent.py @@ -3,10 +3,16 @@ import kuzu from langchain.prompts import PromptTemplate +from langchain.tools import Tool +from motleycrew import MotleyTool from motleycrew.storage import MotleyKuzuGraphStore from motleycrew.tool.llm_tool import LLMTool +from question_struct import Question +from question_generator import QuestionGeneratorTool +from question_generator import QuestionGeneratorToolInput + logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -27,20 +33,21 @@ class KnowledgeGainingOrchestrator: - def __init__(self, db_path: str): + def __init__(self, db_path: str, query_tool: MotleyTool): self.db = kuzu.Database(db_path) self.storage = MotleyKuzuGraphStore( self.db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"} ) + self.query_tool = query_tool self.question_prioritization_tool = LLMTool( name="question_prioritization_tool", description="find the most important question", prompt=QUESTION_PRIORITIZATION_TEMPLATE, ) - self.question_generation_tool = None + self.question_generation_tool = QuestionGeneratorTool(query_tool=query_tool, graph=self.storage) - def get_unanswered_questions(self, only_without_children: bool = False) -> list[dict]: + def get_unanswered_questions(self, only_without_children: bool = False) -> list[Question]: if only_without_children: query = "MATCH (n1:{}) WHERE n1.answer IS NULL AND NOT (n1)-[:{}]->(:{}) RETURN n1;".format( self.storage.node_table_name, self.storage.rel_table_name, self.storage.node_table_name @@ -49,7 +56,7 @@ def get_unanswered_questions(self, only_without_children: bool = False) -> list[ query = "MATCH (n1:{}) WHERE n1.answer IS NULL RETURN n1;".format(self.storage.node_table_name) query_result = self.storage.run_query(query) - return [row[0] for row in query_result] # flatten + return [Question.deserialize(row[0]) for row in query_result] def __call__(self, query: str, max_iter: int): self.storage.create_entity({"question": query}) @@ -60,25 +67,50 @@ def __call__(self, query: str, max_iter: int): unanswered_questions = self.get_unanswered_questions(only_without_children=True) logging.info("Loaded unanswered questions: %s", unanswered_questions) - tool_input = "\n".join(f"{i}. {question}" for i, question in enumerate(unanswered_questions)) - most_pertinent_question_raw = self.question_prioritization_tool.invoke(tool_input) + question_prioritization_tool_input = { + "unanswered_questions": "\n".join( + f"{i}. {question.question}" for i, question in enumerate(unanswered_questions) + ), + "original_question": query, + } + most_pertinent_question_raw = self.question_prioritization_tool.invoke( + question_prioritization_tool_input + ).content logging.info("Most pertinent question according to the tool: %s", most_pertinent_question_raw) i, most_pertinent_question_text = most_pertinent_question_raw.split(".", 1) + i = int(i) assert i < len(unanswered_questions) most_pertinent_question = unanswered_questions[i] - assert most_pertinent_question_text.strip() == most_pertinent_question["question"].strip() + assert most_pertinent_question_text.strip() == most_pertinent_question.question.strip() logging.info("Generating new questions") + self.question_generation_tool.invoke({"question": most_pertinent_question}) if __name__ == "__main__": from pathlib import Path import shutil + from dotenv import load_dotenv + load_dotenv() here = Path(__file__).parent db_path = here / "research_db" shutil.rmtree(db_path, ignore_errors=True) - orchestrator = KnowledgeGainingOrchestrator(db_path=str(db_path)) + query_tool = MotleyTool.from_langchain_tool( + Tool.from_function( + func=lambda question: [ + "Germany has consisted of many different states over the years", + "The capital of France has moved in 1815, from Lyons to Paris", + "France actually has two capitals, one in the north and one in the south", + ], + name="Query Tool", + description="Query the library for relevant information.", + args_schema=QuestionGeneratorToolInput, + ) + ) + + orchestrator = KnowledgeGainingOrchestrator(db_path=str(db_path), query_tool=query_tool) + orchestrator(query="Why did Arjuna kill his step-brother?", max_iter=5) diff --git a/motleycrew/storage/__init__.py b/motleycrew/storage/__init__.py index 1f544aec..518564ea 100644 --- a/motleycrew/storage/__init__.py +++ b/motleycrew/storage/__init__.py @@ -1 +1,3 @@ +from .graph_store import MotleyGraphStore + from .kuzu_graph_store import MotleyKuzuGraphStore diff --git a/motleycrew/storage/graph_store.py b/motleycrew/storage/graph_store.py new file mode 100644 index 00000000..fd7fe6b2 --- /dev/null +++ b/motleycrew/storage/graph_store.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Optional, Any + + +class MotleyGraphStore(ABC): + @abstractmethod + def check_entity_exists(self, entity_id: int) -> bool: + pass + + @abstractmethod + def get_entity(self, entity_id: int) -> Optional[dict]: + pass + + @abstractmethod + def create_entity(self, entity: dict) -> dict: + """Create a new entity and return it""" + pass + + @abstractmethod + def create_rel(self, from_id: int, to_id: int, predicate: str) -> None: + pass + + @abstractmethod + def delete_entity(self, entity_id: int) -> None: + """Delete a given entity and its relations""" + pass + + def set_property(self, entity_id: int, property_name: str, property_value: Any): + pass diff --git a/motleycrew/storage/kuzu_graph_store.py b/motleycrew/storage/kuzu_graph_store.py index 96b9f48f..31ecbe44 100644 --- a/motleycrew/storage/kuzu_graph_store.py +++ b/motleycrew/storage/kuzu_graph_store.py @@ -7,8 +7,10 @@ import kuzu +from motleycrew.storage import MotleyGraphStore -class MotleyKuzuGraphStore: + +class MotleyKuzuGraphStore(MotleyGraphStore): def __init__( self, database: Any, @@ -81,11 +83,11 @@ def _dict_to_cypher_mapping_with_parameters(entity: dict) -> tuple[str, dict]: cypher_mapping = cypher_mapping.rstrip(", ") + "}" return cypher_mapping, parameters - def create_entity(self, entity: dict) -> int: + def create_entity(self, entity: dict) -> dict: """Create a new entity and return its id""" cypher_mapping, parameters = MotleyKuzuGraphStore._dict_to_cypher_mapping_with_parameters(entity) create_result = self.connection.execute( - "CREATE (n:{} {}) RETURN n.id".format(self.node_table_name, cypher_mapping), parameters=parameters + "CREATE (n:{} {}) RETURN n".format(self.node_table_name, cypher_mapping), parameters=parameters ) assert create_result.has_next() return create_result.get_next()[0] diff --git a/motleycrew/tool/question_insertion_tool.py b/motleycrew/tool/question_insertion_tool.py deleted file mode 100644 index 3d2f57f2..00000000 --- a/motleycrew/tool/question_insertion_tool.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import List - -from pathlib import Path - -from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_core.tools import Tool - -# TODO: fallback interface if LlamaIndex is not available -from llama_index.core.graph_stores.types import GraphStore - -from motleycrew.tool import MotleyTool - - -class QuestionInsertionTool(MotleyTool): - def __init__(self, question: str, graph: GraphStore): - - langchain_tool = create_question_insertion_langchain_tool( - name="Question Insertion Tool", - description="Insert a list of questions (supplied as a list of strings) into the graph.", - question=question, - graph=graph, - ) - - super().__init__(langchain_tool) - - -class QuestionInsertionToolInput(BaseModel): - """Subquestions of the current question, to be inserted into the knowledge graph.""" - - questions: List[str] = Field( - description="List of questions to be inserted into the knowledge graph." - ) - - -def create_question_insertion_langchain_tool( - name: str, - description: str, - question: str, - graph: GraphStore, -): - def insert_questions(questions: list[str]) -> None: - for subquestion in questions: - # TODO: change! This is a placeholder implementation - graph.upsert_triplet(question, "IS_SUBQUESTION", subquestion) - - return Tool.from_function( - func=insert_questions, - name=name, - description=description, - args_schema=QuestionInsertionToolInput, - ) - - -if __name__ == "__main__": - import kuzu - from llama_index.graph_stores.kuzu import KuzuGraphStore - - here = Path(__file__).parent - db_path = here / "test1" - db = kuzu.Database(db_path) - graph_store = KuzuGraphStore(db) - - children_1 = ["What is the capital of France?", "What is the capital of Germany?"] - children_2 = ["What is the capital of Italy?", "What is the capital of Spain?"] - tool = QuestionInsertionTool(question="Starting question", graph=graph_store) - tool.invoke({"questions": children_1}) - tool2 = QuestionInsertionTool( - question="What is the capital of France?", graph=graph_store - ) - tool2.invoke({"questions": children_2}) - print( - f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest" - ) - print("MATCH (A)-[r]->(B) RETURN *;")