Skip to content

Commit

Permalink
Merge branch 'research-agent' of https://github.com/ShoggothAI/motley…
Browse files Browse the repository at this point in the history
…crew into research_assistant
  • Loading branch information
ZmeiGorynych committed Apr 26, 2024
2 parents 5efcf7c + e6ff0a1 commit 43cd69c
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 138 deletions.
2 changes: 1 addition & 1 deletion motleycrew/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from kuzu_graph_store import MotleyKuzuGraphStore
from .kuzu_graph_store import MotleyQuestionGraphStore
239 changes: 102 additions & 137 deletions motleycrew/storage/kuzu_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

from typing import Any, Dict, List, Optional

from llama_index.core.graph_stores.types import GraphStore

import kuzu


class MotleyKuzuGraphStore(GraphStore):
class MotleyQuestionGraphStore:
IS_SUBQUESTION_PREDICATE = "IS_SUBQUESTION"

def __init__(
self,
database: Any,
node_table_name: str = "entity",
node_table_name: str = "question",
rel_table_name: str = "links",
**kwargs: Any,
) -> None:
Expand All @@ -29,7 +29,7 @@ def init_schema(self) -> None:
node_tables = self.connection._get_node_table_names()
if self.node_table_name not in node_tables:
self.connection.execute(
"CREATE NODE TABLE %s (ID STRING, PRIMARY KEY(ID))"
"CREATE NODE TABLE %s (ID SERIAL, question STRING, answer STRING, context STRING[], PRIMARY KEY(ID))"
% self.node_table_name
)
rel_tables = self.connection._get_rel_table_names()
Expand All @@ -45,172 +45,111 @@ def init_schema(self) -> None:
def client(self) -> Any:
return self.connection

def get(self, subj: str) -> List[List[str]]:
"""Get triplets."""
def check_question_exists(self, question_id: int) -> bool:
is_exists_result = self.connection.execute(
"MATCH (n:%s) WHERE n.ID = $question_id RETURN n.ID" % self.node_table_name,
{"question_id": question_id},
)
return is_exists_result.has_next()

def get_question(self, question_id: int) -> Optional[dict]:
query = """
MATCH (n1:%s)-[r:%s]->(n2:%s)
WHERE n1.ID = $subj
RETURN r.predicate, n2.ID;
MATCH (n1:%s)
WHERE n1.ID = $question_id
RETURN n1;
"""
prepared_statement = self.connection.prepare(query % self.node_table_name)
query_result = self.connection.execute(prepared_statement, {"question_id": question_id})

if query_result.has_next():
row = query_result.get_next()
return row[0]

def get_subquestions(self, question_id: int) -> List[int]:
query = """
MATCH (n1:%s)-[r:%s]->(n2:%s)
WHERE n1.ID = $question_id
AND r.predicate = $is_subquestion_predicate
RETURN n2.ID;
"""
prepared_statement = self.connection.prepare(
query % (self.node_table_name, self.rel_table_name, self.node_table_name)
)
query_result = self.connection.execute(prepared_statement, {"subj": subj})
query_result = self.connection.execute(
prepared_statement,
{
"question_id": question_id,
"is_subquestion_predicate": MotleyQuestionGraphStore.IS_SUBQUESTION_PREDICATE,
},
)
retval = []
while query_result.has_next():
row = query_result.get_next()
retval.append([row[0], row[1]])
retval.append(row[0])
return retval

def get_rel_map(
self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
) -> Dict[str, List[List[str]]]:
"""Get depth-aware rel map."""
rel_wildcard = "r:%s*1..%d" % (self.rel_table_name, depth)
match_clause = "MATCH (n1:{})-[{}]->(n2:{})".format(
self.node_table_name,
rel_wildcard,
self.node_table_name,
def create_question(self, question: str) -> int:
create_result = self.connection.execute(
"CREATE (n:%s {question: $question}) " "RETURN n.ID" % self.node_table_name,
{"question": question},
)
return_clause = "RETURN n1, r, n2 LIMIT %d" % limit
params = []
if subjs is not None:
for i, curr_subj in enumerate(subjs):
if i == 0:
where_clause = "WHERE n1.ID = $%d" % i
else:
where_clause += " OR n1.ID = $%d" % i
params.append((str(i), curr_subj))
else:
where_clause = ""
query = f"{match_clause} {where_clause} {return_clause}"
prepared_statement = self.connection.prepare(query)
if subjs is not None:
query_result = self.connection.execute(
prepared_statement, {k: v for k, v in params}
)
else:
query_result = self.connection.execute(prepared_statement)
retval: Dict[str, List[List[str]]] = {}
while query_result.has_next():
row = query_result.get_next()
curr_path = []
subj = row[0]
recursive_rel = row[1]
obj = row[2]
nodes_map = {}
nodes_map[(subj["_id"]["table"], subj["_id"]["offset"])] = subj["ID"]
nodes_map[(obj["_id"]["table"], obj["_id"]["offset"])] = obj["ID"]
for node in recursive_rel["_nodes"]:
nodes_map[(node["_id"]["table"], node["_id"]["offset"])] = node["ID"]
for rel in recursive_rel["_rels"]:
predicate = rel["predicate"]
curr_subj_id = nodes_map[(rel["_src"]["table"], rel["_src"]["offset"])]
curr_path.append(curr_subj_id)
curr_path.append(predicate)
# Add the last node
curr_path.append(obj["ID"])
if subj["ID"] not in retval:
retval[subj["ID"]] = []
retval[subj["ID"]].append(curr_path)
return retval

def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
"""Add triplet."""

def check_entity_exists(connection: Any, entity: str) -> bool:
is_exists_result = connection.execute(
"MATCH (n:%s) WHERE n.ID = $entity RETURN n.ID" % self.node_table_name,
{"entity": entity},
)
return is_exists_result.has_next()
assert create_result.has_next()
return create_result.get_next()[0]

def create_entity(connection: Any, entity: str) -> None:
def create_subquestion(self, question_id: int, subquestion: str) -> int:
def create_subquestion_rel(connection: Any, question_id: int, subquestion_id: int) -> None:
connection.execute(
"CREATE (n:%s {ID: $entity})" % self.node_table_name,
{"entity": entity},
)

def check_rel_exists(connection: Any, subj: str, obj: str, rel: str) -> bool:
is_exists_result = connection.execute(
(
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID = "
"$obj AND r.predicate = $pred RETURN r.predicate"
).format(
self.node_table_name, self.rel_table_name, self.node_table_name
),
{"subj": subj, "obj": obj, "pred": rel},
"MATCH (n1:{}), (n2:{}) WHERE n1.ID = $question_id AND n2.ID = $subquestion_id "
"CREATE (n1)-[r:{} {{predicate: $is_subquestion_predicate}}]->(n2)"
).format(self.node_table_name, self.node_table_name, self.rel_table_name),
{
"question_id": question_id,
"subquestion_id": subquestion_id,
"is_subquestion_predicate": MotleyQuestionGraphStore.IS_SUBQUESTION_PREDICATE,
},
)
return is_exists_result.has_next()

def create_rel(connection: Any, subj: str, obj: str, rel: str) -> None:
connection.execute(
(
"MATCH (n1:{}), (n2:{}) WHERE n1.ID = $subj AND n2.ID = $obj "
"CREATE (n1)-[r:{} {{predicate: $pred}}]->(n2)"
).format(
self.node_table_name, self.node_table_name, self.rel_table_name
),
{"subj": subj, "obj": obj, "pred": rel},
)

is_subj_exists = check_entity_exists(self.connection, subj)
is_obj_exists = check_entity_exists(self.connection, obj)

if not is_subj_exists:
create_entity(self.connection, subj)
if not is_obj_exists:
create_entity(self.connection, obj)

if is_subj_exists and is_obj_exists:
is_rel_exists = check_rel_exists(self.connection, subj, obj, rel)
if is_rel_exists:
return
if not self.check_question_exists(question_id):
raise Exception(f"No question with id {question_id}")

create_rel(self.connection, subj, obj, rel)
subquestion_id = self.create_question(subquestion)
create_subquestion_rel(self.connection, question_id=question_id, subquestion_id=subquestion_id)
return subquestion_id

def delete(self, subj: str, rel: str, obj: str) -> None:
"""Delete triplet."""
def delete_question(self, question_id: int) -> None:
"""Deletes question and its relations."""

def delete_rel(connection: Any, subj: str, obj: str, rel: str) -> None:
def delete_rels(connection: Any, question_id: int) -> None:
connection.execute(
(
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID"
" = $obj AND r.predicate = $pred DELETE r"
).format(
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $question_id DELETE r".format(
self.node_table_name, self.rel_table_name, self.node_table_name
),
{"subj": subj, "obj": obj, "pred": rel},
{"question_id": question_id},
)

def delete_entity(connection: Any, entity: str) -> None:
connection.execute(
"MATCH (n:%s) WHERE n.ID = $entity DELETE n" % self.node_table_name,
{"entity": entity},
)

def check_edges(connection: Any, entity: str) -> bool:
is_exists_result = connection.execute(
"MATCH (n1:{})-[r:{}]-(n2:{}) WHERE n2.ID = $entity RETURN r.predicate".format(
"MATCH (n1:{})<-[r:{}]-(n2:{}) WHERE n1.ID = $question_id DELETE r".format(
self.node_table_name, self.rel_table_name, self.node_table_name
),
{"entity": entity},
{"question_id": question_id},
)
return is_exists_result.has_next()

delete_rel(self.connection, subj, obj, rel)
if not check_edges(self.connection, subj):
delete_entity(self.connection, subj)
if not check_edges(self.connection, obj):
delete_entity(self.connection, obj)
def delete_question(connection: Any, question_id: int) -> None:
connection.execute(
"MATCH (n:%s) WHERE n.ID = $question_id DELETE n" % self.node_table_name,
{"question_id": question_id},
)

delete_rels(self.connection, question_id)
delete_question(self.connection, question_id)

@classmethod
def from_persist_dir(
cls,
persist_dir: str,
node_table_name: str = "entity",
rel_table_name: str = "links",
) -> "MotleyKuzuGraphStore":
) -> "MotleyQuestionGraphStore":
"""Load from persist dir."""
try:
import kuzu
Expand All @@ -220,7 +159,7 @@ def from_persist_dir(
return cls(database, node_table_name, rel_table_name)

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore":
def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyQuestionGraphStore":
"""Initialize graph store from configuration dictionary.
Args:
Expand All @@ -230,3 +169,29 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore":
Graph store.
"""
return cls(**config_dict)


if __name__ == "__main__":
from pathlib import Path

here = Path(__file__).parent
db_path = here / "test1"
db = kuzu.Database(str(db_path))
graph_store = MotleyQuestionGraphStore(db)

q1_id = graph_store.create_question("q1")
assert graph_store.get_question(q1_id)["question"] == "q1"

q2_id = graph_store.create_subquestion(q1_id, "q2")
q3_id = graph_store.create_subquestion(q1_id, "q3")
q4_id = graph_store.create_subquestion(q3_id, "q4")

assert set(graph_store.get_subquestions(q1_id)) == {q2_id, q3_id}
assert set(graph_store.get_subquestions(q3_id)) == {q4_id}

graph_store.delete_question(q4_id)
assert graph_store.get_question(q4_id) is None
assert not graph_store.get_subquestions(q3_id)

print(f"docker run -p 8000:8000 -v {db_path}:/database --rm kuzudb/explorer: latest")
print("MATCH (A)-[r]->(B) RETURN *;")

0 comments on commit 43cd69c

Please sign in to comment.