Skip to content

Commit

Permalink
Kuzu graph store entity creation fix
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Apr 26, 2024
1 parent 068f5f4 commit 75071a4
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions motleycrew/storage/kuzu_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Any, Dict, List, Optional

import json
import kuzu


Expand Down Expand Up @@ -70,11 +69,23 @@ def get_entity(self, entity_id: int) -> Optional[dict]:
item = row[0]
return item

@staticmethod
def _dict_to_cypher_mapping_with_parameters(entity: dict) -> tuple[str, dict]:
parameters = {}

cypher_mapping = "{"
for key, value in entity.items():
cypher_mapping += f"{key}: ${key}, "
parameters[key] = value

cypher_mapping = cypher_mapping.rstrip(", ") + "}"
return cypher_mapping, parameters

def create_entity(self, entity: dict) -> int:
"""Create a new entity and return its id"""
cypher_mapping, parameters = MotleyKuzuGraphStore._dict_to_cypher_mapping_with_parameters(entity)
create_result = self.connection.execute(
"CREATE (n:{} $entity) RETURN n.id".format(self.node_table_name),
{"entity": entity},
"CREATE (n:{} {}) RETURN n.id".format(self.node_table_name, cypher_mapping), parameters=parameters
)
assert create_result.has_next()
return create_result.get_next()[0]
Expand All @@ -98,14 +109,8 @@ def delete_entity(self, entity_id: int) -> None:
def delete_rels(connection: Any, entity_id: int) -> None:
# Undirected relation removal is not supported for some reason
connection.execute(
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.id = $entity_id DELETE r;"
"MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.id = $entity_id DELETE r".format(
self.node_table_name,
self.rel_table_name,
self.node_table_name,
self.node_table_name,
self.rel_table_name,
self.node_table_name,
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.id = $entity_id DELETE r".format(
self.node_table_name, self.rel_table_name, self.node_table_name
),
{"entity_id": entity_id},
)
Expand Down Expand Up @@ -174,6 +179,7 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore":
if __name__ == "__main__":
from pathlib import Path
import shutil
import json

here = Path(__file__).parent
db_path = here / "test1"
Expand All @@ -199,10 +205,10 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore":
assert graph_store.get_entity(q4_id) is None

graph_store.set_property(q2_id, property_name="answer", property_value="a2")
graph_store.set_property(q3_id, property_name="", property_value=["c3_1", "c3_2"])
graph_store.set_property(q3_id, property_name="context", property_value=json.dumps(["c3_1", "c3_2"]))

assert graph_store.get_entity(q2_id)["answer"] == "a2"
assert graph_store.get_entity(q3_id)["context"] == ["c3_1", "c3_2"]
assert json.loads(graph_store.get_entity(q3_id)["context"]) == ["c3_1", "c3_2"]

print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest")
print("MATCH (A)-[r]->(B) RETURN *;")

0 comments on commit 75071a4

Please sign in to comment.