From 75071a493e10ca6edc6c11b44325f36887e10417 Mon Sep 17 00:00:00 2001 From: whimo Date: Fri, 26 Apr 2024 13:27:53 +0400 Subject: [PATCH] Kuzu graph store entity creation fix --- motleycrew/storage/kuzu_graph_store.py | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/motleycrew/storage/kuzu_graph_store.py b/motleycrew/storage/kuzu_graph_store.py index d2831b67..96b9f48f 100644 --- a/motleycrew/storage/kuzu_graph_store.py +++ b/motleycrew/storage/kuzu_graph_store.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional -import json import kuzu @@ -70,11 +69,23 @@ def get_entity(self, entity_id: int) -> Optional[dict]: item = row[0] return item + @staticmethod + def _dict_to_cypher_mapping_with_parameters(entity: dict) -> tuple[str, dict]: + parameters = {} + + cypher_mapping = "{" + for key, value in entity.items(): + cypher_mapping += f"{key}: ${key}, " + parameters[key] = value + + cypher_mapping = cypher_mapping.rstrip(", ") + "}" + return cypher_mapping, parameters + def create_entity(self, entity: dict) -> int: """Create a new entity and return its id""" + cypher_mapping, parameters = MotleyKuzuGraphStore._dict_to_cypher_mapping_with_parameters(entity) create_result = self.connection.execute( - "CREATE (n:{} $entity) RETURN n.id".format(self.node_table_name), - {"entity": entity}, + "CREATE (n:{} {}) RETURN n.id".format(self.node_table_name, cypher_mapping), parameters=parameters ) assert create_result.has_next() return create_result.get_next()[0] @@ -98,14 +109,8 @@ def delete_entity(self, entity_id: int) -> None: def delete_rels(connection: Any, entity_id: int) -> None: # Undirected relation removal is not supported for some reason connection.execute( - "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.id = $entity_id DELETE r;" - "MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.id = $entity_id DELETE r".format( - self.node_table_name, - self.rel_table_name, - self.node_table_name, - self.node_table_name, - self.rel_table_name, - self.node_table_name, + "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.id = $entity_id DELETE r".format( + self.node_table_name, self.rel_table_name, self.node_table_name ), {"entity_id": entity_id}, ) @@ -174,6 +179,7 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore": if __name__ == "__main__": from pathlib import Path import shutil + import json here = Path(__file__).parent db_path = here / "test1" @@ -199,10 +205,10 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore": assert graph_store.get_entity(q4_id) is None graph_store.set_property(q2_id, property_name="answer", property_value="a2") - graph_store.set_property(q3_id, property_name="", property_value=["c3_1", "c3_2"]) + graph_store.set_property(q3_id, property_name="context", property_value=json.dumps(["c3_1", "c3_2"])) assert graph_store.get_entity(q2_id)["answer"] == "a2" - assert graph_store.get_entity(q3_id)["context"] == ["c3_1", "c3_2"] + assert json.loads(graph_store.get_entity(q3_id)["context"]) == ["c3_1", "c3_2"] print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest") print("MATCH (A)-[r]->(B) RETURN *;")