-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Knowledge gaining orchestrator working implementation
- Loading branch information
Showing
8 changed files
with
203 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import List | ||
|
||
from pathlib import Path | ||
|
||
from langchain_core.pydantic_v1 import BaseModel, Field | ||
from langchain_core.tools import Tool | ||
|
||
from motleycrew.storage import MotleyGraphStore | ||
from motleycrew.tool import MotleyTool | ||
|
||
from question_struct import Question | ||
|
||
|
||
IS_SUBQUESTION_PREDICATE = "is_subquestion" | ||
|
||
|
||
class QuestionInsertionTool(MotleyTool): | ||
def __init__(self, question: Question, graph: MotleyGraphStore): | ||
|
||
langchain_tool = create_question_insertion_langchain_tool( | ||
name="Question Insertion Tool", | ||
description="Insert a list of questions (supplied as a list of strings) into the graph.", | ||
question=question, | ||
graph=graph, | ||
) | ||
|
||
super().__init__(langchain_tool) | ||
|
||
|
||
class QuestionInsertionToolInput(BaseModel): | ||
"""Subquestions of the current question, to be inserted into the knowledge graph.""" | ||
|
||
questions: List[str] = Field(description="List of questions to be inserted into the knowledge graph.") | ||
|
||
|
||
def create_question_insertion_langchain_tool( | ||
name: str, | ||
description: str, | ||
question: Question, | ||
graph: MotleyGraphStore, | ||
): | ||
def insert_questions(questions: list[str]) -> None: | ||
for subquestion in questions: | ||
subquestion_data = graph.create_entity(Question(question=subquestion).serialize()) | ||
subquestion_obj = Question.deserialize(subquestion_data) | ||
graph.create_rel(from_id=question.id, to_id=subquestion_obj.id, predicate=IS_SUBQUESTION_PREDICATE) | ||
|
||
return Tool.from_function( | ||
func=insert_questions, | ||
name=name, | ||
description=description, | ||
args_schema=QuestionInsertionToolInput, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
import kuzu | ||
from motleycrew.storage import MotleyKuzuGraphStore | ||
|
||
here = Path(__file__).parent | ||
db_path = here / "test1" | ||
db = kuzu.Database(db_path) | ||
graph_store = MotleyKuzuGraphStore( | ||
db, node_table_schema={"question": "STRING", "answer": "STRING", "context": "STRING"} | ||
) | ||
|
||
question_data = graph_store.create_entity(Question(question="What is the capital of France?").serialize()) | ||
question = Question.deserialize(question_data) | ||
|
||
children = ["What is the capital of France?", "What is the capital of Germany?"] | ||
tool = QuestionInsertionTool(question=question, graph=graph_store) | ||
tool.invoke({"questions": children}) | ||
|
||
print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest") | ||
print("MATCH (A)-[r]->(B) RETURN *;") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Optional | ||
from dataclasses import dataclass | ||
import json | ||
|
||
|
||
@dataclass | ||
class Question: | ||
id: Optional[int] = None | ||
question: Optional[str] = None | ||
answer: Optional[str] = None | ||
context: Optional[list[str]] = None | ||
|
||
def serialize(self): | ||
data = {} | ||
|
||
if self.id: | ||
data["id"] = json.dumps(self.id) | ||
if self.context: | ||
data["question"] = json.dumps(self.question) | ||
if self.context: | ||
data["answer"] = json.dumps(self.answer) | ||
if self.context: | ||
data["context"] = json.dumps(self.context) | ||
|
||
return data | ||
|
||
@staticmethod | ||
def deserialize(data: dict): | ||
context_raw = data["context"] | ||
if context_raw: | ||
context = json.loads(context_raw) | ||
else: | ||
context = None | ||
|
||
return Question(id=data["id"], question=data["question"], answer=data["answer"], context=context) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .graph_store import MotleyGraphStore | ||
|
||
from .kuzu_graph_store import MotleyKuzuGraphStore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Optional, Any | ||
|
||
|
||
class MotleyGraphStore(ABC): | ||
@abstractmethod | ||
def check_entity_exists(self, entity_id: int) -> bool: | ||
pass | ||
|
||
@abstractmethod | ||
def get_entity(self, entity_id: int) -> Optional[dict]: | ||
pass | ||
|
||
@abstractmethod | ||
def create_entity(self, entity: dict) -> dict: | ||
"""Create a new entity and return it""" | ||
pass | ||
|
||
@abstractmethod | ||
def create_rel(self, from_id: int, to_id: int, predicate: str) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def delete_entity(self, entity_id: int) -> None: | ||
"""Delete a given entity and its relations""" | ||
pass | ||
|
||
def set_property(self, entity_id: int, property_name: str, property_value: Any): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.