From dc0c4c9eee3f6bd73bd6584339e53658072a7669 Mon Sep 17 00:00:00 2001 From: whimo Date: Thu, 25 Apr 2024 17:11:07 +0400 Subject: [PATCH 1/2] Graph store implementation draft --- motleycrew/storage/__init__.py | 2 +- motleycrew/storage/kuzu_graph_store.py | 238 +++++++++++-------------- 2 files changed, 104 insertions(+), 136 deletions(-) diff --git a/motleycrew/storage/__init__.py b/motleycrew/storage/__init__.py index 1c29f4b6..096759ad 100644 --- a/motleycrew/storage/__init__.py +++ b/motleycrew/storage/__init__.py @@ -1 +1 @@ -from kuzu_graph_store import MotleyKuzuGraphStore +from .kuzu_graph_store import MotleyQuestionGraphStore diff --git a/motleycrew/storage/kuzu_graph_store.py b/motleycrew/storage/kuzu_graph_store.py index 8d29f646..3e0f7a24 100644 --- a/motleycrew/storage/kuzu_graph_store.py +++ b/motleycrew/storage/kuzu_graph_store.py @@ -10,11 +10,13 @@ import kuzu -class MotleyKuzuGraphStore(GraphStore): +class MotleyQuestionGraphStore(GraphStore): + IS_SUBQUESTION_PREDICATE = "IS_SUBQUESTION" + def __init__( self, database: Any, - node_table_name: str = "entity", + node_table_name: str = "question", rel_table_name: str = "links", **kwargs: Any, ) -> None: @@ -29,7 +31,7 @@ def init_schema(self) -> None: node_tables = self.connection._get_node_table_names() if self.node_table_name not in node_tables: self.connection.execute( - "CREATE NODE TABLE %s (ID STRING, PRIMARY KEY(ID))" + "CREATE NODE TABLE %s (ID SERIAL, question STRING, answer STRING, context STRING[], PRIMARY KEY(ID))" % self.node_table_name ) rel_tables = self.connection._get_rel_table_names() @@ -45,164 +47,103 @@ def init_schema(self) -> None: def client(self) -> Any: return self.connection - def get(self, subj: str) -> List[List[str]]: - """Get triplets.""" + def check_question_exists(self, question_id: int) -> bool: + is_exists_result = self.connection.execute( + "MATCH (n:%s) WHERE n.ID = $question_id RETURN n.ID" % self.node_table_name, + {"question_id": question_id}, + ) + return is_exists_result.has_next() + + def get_question(self, question_id: int) -> Optional[dict]: query = """ - MATCH (n1:%s)-[r:%s]->(n2:%s) - WHERE n1.ID = $subj - RETURN r.predicate, n2.ID; + MATCH (n1:%s) + WHERE n1.ID = $question_id + RETURN n1; """ + prepared_statement = self.connection.prepare(query % self.node_table_name) + query_result = self.connection.execute(prepared_statement, {"question_id": question_id}) + + if query_result.has_next(): + row = query_result.get_next() + return row[0] + + def get_subquestions(self, question_id: int) -> List[int]: + query = """ + MATCH (n1:%s)-[r:%s]->(n2:%s) + WHERE n1.ID = $question_id + AND r.predicate = $is_subquestion_predicate + RETURN n2.ID; + """ prepared_statement = self.connection.prepare( query % (self.node_table_name, self.rel_table_name, self.node_table_name) ) - query_result = self.connection.execute(prepared_statement, {"subj": subj}) + query_result = self.connection.execute( + prepared_statement, + { + "question_id": question_id, + "is_subquestion_predicate": MotleyQuestionGraphStore.IS_SUBQUESTION_PREDICATE, + }, + ) retval = [] while query_result.has_next(): row = query_result.get_next() - retval.append([row[0], row[1]]) + retval.append(row[0]) return retval - def get_rel_map( - self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30 - ) -> Dict[str, List[List[str]]]: - """Get depth-aware rel map.""" - rel_wildcard = "r:%s*1..%d" % (self.rel_table_name, depth) - match_clause = "MATCH (n1:{})-[{}]->(n2:{})".format( - self.node_table_name, - rel_wildcard, - self.node_table_name, + def create_question(self, question: str) -> int: + create_result = self.connection.execute( + "CREATE (n:%s {question: $question}) " "RETURN n.ID" % self.node_table_name, + {"question": question}, ) - return_clause = "RETURN n1, r, n2 LIMIT %d" % limit - params = [] - if subjs is not None: - for i, curr_subj in enumerate(subjs): - if i == 0: - where_clause = "WHERE n1.ID = $%d" % i - else: - where_clause += " OR n1.ID = $%d" % i - params.append((str(i), curr_subj)) - else: - where_clause = "" - query = f"{match_clause} {where_clause} {return_clause}" - prepared_statement = self.connection.prepare(query) - if subjs is not None: - query_result = self.connection.execute( - prepared_statement, {k: v for k, v in params} - ) - else: - query_result = self.connection.execute(prepared_statement) - retval: Dict[str, List[List[str]]] = {} - while query_result.has_next(): - row = query_result.get_next() - curr_path = [] - subj = row[0] - recursive_rel = row[1] - obj = row[2] - nodes_map = {} - nodes_map[(subj["_id"]["table"], subj["_id"]["offset"])] = subj["ID"] - nodes_map[(obj["_id"]["table"], obj["_id"]["offset"])] = obj["ID"] - for node in recursive_rel["_nodes"]: - nodes_map[(node["_id"]["table"], node["_id"]["offset"])] = node["ID"] - for rel in recursive_rel["_rels"]: - predicate = rel["predicate"] - curr_subj_id = nodes_map[(rel["_src"]["table"], rel["_src"]["offset"])] - curr_path.append(curr_subj_id) - curr_path.append(predicate) - # Add the last node - curr_path.append(obj["ID"]) - if subj["ID"] not in retval: - retval[subj["ID"]] = [] - retval[subj["ID"]].append(curr_path) - return retval - - def upsert_triplet(self, subj: str, rel: str, obj: str) -> None: - """Add triplet.""" - - def check_entity_exists(connection: Any, entity: str) -> bool: - is_exists_result = connection.execute( - "MATCH (n:%s) WHERE n.ID = $entity RETURN n.ID" % self.node_table_name, - {"entity": entity}, - ) - return is_exists_result.has_next() - - def create_entity(connection: Any, entity: str) -> None: - connection.execute( - "CREATE (n:%s {ID: $entity})" % self.node_table_name, - {"entity": entity}, - ) - - def check_rel_exists(connection: Any, subj: str, obj: str, rel: str) -> bool: - is_exists_result = connection.execute( - ( - "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID = " - "$obj AND r.predicate = $pred RETURN r.predicate" - ).format( - self.node_table_name, self.rel_table_name, self.node_table_name - ), - {"subj": subj, "obj": obj, "pred": rel}, - ) - return is_exists_result.has_next() + assert create_result.has_next() + return create_result.get_next()[0] - def create_rel(connection: Any, subj: str, obj: str, rel: str) -> None: + def create_subquestion(self, question_id: int, subquestion: str) -> int: + def create_subquestion_rel(connection: Any, question_id: int, subquestion_id: int) -> None: connection.execute( ( - "MATCH (n1:{}), (n2:{}) WHERE n1.ID = $subj AND n2.ID = $obj " - "CREATE (n1)-[r:{} {{predicate: $pred}}]->(n2)" - ).format( - self.node_table_name, self.node_table_name, self.rel_table_name - ), - {"subj": subj, "obj": obj, "pred": rel}, + "MATCH (n1:{}), (n2:{}) WHERE n1.ID = $question_id AND n2.ID = $subquestion_id " + "CREATE (n1)-[r:{} {{predicate: $is_subquestion_predicate}}]->(n2)" + ).format(self.node_table_name, self.node_table_name, self.rel_table_name), + { + "question_id": question_id, + "subquestion_id": subquestion_id, + "is_subquestion_predicate": MotleyQuestionGraphStore.IS_SUBQUESTION_PREDICATE, + }, ) - is_subj_exists = check_entity_exists(self.connection, subj) - is_obj_exists = check_entity_exists(self.connection, obj) - - if not is_subj_exists: - create_entity(self.connection, subj) - if not is_obj_exists: - create_entity(self.connection, obj) - - if is_subj_exists and is_obj_exists: - is_rel_exists = check_rel_exists(self.connection, subj, obj, rel) - if is_rel_exists: - return + if not self.check_question_exists(question_id): + raise Exception(f"No question with id {question_id}") - create_rel(self.connection, subj, obj, rel) + subquestion_id = self.create_question(subquestion) + create_subquestion_rel(self.connection, question_id=question_id, subquestion_id=subquestion_id) + return subquestion_id - def delete(self, subj: str, rel: str, obj: str) -> None: - """Delete triplet.""" + def delete_question(self, question_id: int) -> None: + """Deletes question and its relations.""" - def delete_rel(connection: Any, subj: str, obj: str, rel: str) -> None: + def delete_rels(connection: Any, question_id: int) -> None: connection.execute( - ( - "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID" - " = $obj AND r.predicate = $pred DELETE r" - ).format( + "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $question_id DELETE r".format( self.node_table_name, self.rel_table_name, self.node_table_name ), - {"subj": subj, "obj": obj, "pred": rel}, + {"question_id": question_id}, ) - - def delete_entity(connection: Any, entity: str) -> None: connection.execute( - "MATCH (n:%s) WHERE n.ID = $entity DELETE n" % self.node_table_name, - {"entity": entity}, - ) - - def check_edges(connection: Any, entity: str) -> bool: - is_exists_result = connection.execute( - "MATCH (n1:{})-[r:{}]-(n2:{}) WHERE n2.ID = $entity RETURN r.predicate".format( + "MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.ID = $question_id DELETE r".format( self.node_table_name, self.rel_table_name, self.node_table_name ), - {"entity": entity}, + {"question_id": question_id}, ) - return is_exists_result.has_next() - delete_rel(self.connection, subj, obj, rel) - if not check_edges(self.connection, subj): - delete_entity(self.connection, subj) - if not check_edges(self.connection, obj): - delete_entity(self.connection, obj) + def delete_question(connection: Any, question_id: int) -> None: + connection.execute( + "MATCH (n:%s) WHERE n.ID = $question_id DELETE n" % self.node_table_name, + {"question_id": question_id}, + ) + + delete_rels(self.connection, question_id) + delete_question(self.connection, question_id) @classmethod def from_persist_dir( @@ -210,7 +151,7 @@ def from_persist_dir( persist_dir: str, node_table_name: str = "entity", rel_table_name: str = "links", - ) -> "MotleyKuzuGraphStore": + ) -> "MotleyQuestionGraphStore": """Load from persist dir.""" try: import kuzu @@ -220,7 +161,7 @@ def from_persist_dir( return cls(database, node_table_name, rel_table_name) @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore": + def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyQuestionGraphStore": """Initialize graph store from configuration dictionary. Args: @@ -230,3 +171,30 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore": Graph store. """ return cls(**config_dict) + + +if __name__ == "__main__": + from pathlib import Path + import kuzu + + here = Path(__file__).parent + db_path = here / "test1" + db = kuzu.Database(str(db_path)) + graph_store = MotleyQuestionGraphStore(db) + + q1_id = graph_store.create_question("q1") + assert graph_store.get_question(q1_id)["question"] == "q1" + + q2_id = graph_store.create_subquestion(q1_id, "q2") + q3_id = graph_store.create_subquestion(q1_id, "q3") + q4_id = graph_store.create_subquestion(q3_id, "q4") + + assert set(graph_store.get_subquestions(q1_id)) == {q2_id, q3_id} + assert set(graph_store.get_subquestions(q3_id)) == {q4_id} + + graph_store.delete_question(q4_id) + assert graph_store.get_question(q4_id) is None + assert not graph_store.get_subquestions(q3_id) + + print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest") + print("MATCH (A)-[r]->(B) RETURN *;") From e6ff0a14ecbaa3b352ed9a6f9a3654e30b5d2f53 Mon Sep 17 00:00:00 2001 From: whimo Date: Thu, 25 Apr 2024 17:12:29 +0400 Subject: [PATCH 2/2] Remove redundant inheritance --- motleycrew/storage/kuzu_graph_store.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/motleycrew/storage/kuzu_graph_store.py b/motleycrew/storage/kuzu_graph_store.py index 3e0f7a24..e930c281 100644 --- a/motleycrew/storage/kuzu_graph_store.py +++ b/motleycrew/storage/kuzu_graph_store.py @@ -5,12 +5,10 @@ from typing import Any, Dict, List, Optional -from llama_index.core.graph_stores.types import GraphStore - import kuzu -class MotleyQuestionGraphStore(GraphStore): +class MotleyQuestionGraphStore: IS_SUBQUESTION_PREDICATE = "IS_SUBQUESTION" def __init__( @@ -175,7 +173,6 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyQuestionGraphStore": if __name__ == "__main__": from pathlib import Path - import kuzu here = Path(__file__).parent db_path = here / "test1"