Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llm): Support generate gremlin based on regex #152

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
import json
import re
import jieba.posseg as pseg

from hugegraph_llm.config import huge_settings
from hugegraph_llm.utils.log import log
from pyhugegraph.client import PyHugeClient


class RegexGremlinGenerate:
    """Generate a Gremlin query by part-of-speech matching against cached
    (query, gremlin) examples, without calling an LLM.

    Idea: if an example query and the user query share an identical non-noun
    word "skeleton" (per jieba POS tagging) and contain the same number of
    nouns, reuse the example's gremlin with the example's nouns substituted
    by the user's nouns, then execute it against the graph.
    """

    def __init__(self):
        # Client is created lazily in ``_init_client`` so that a
        # caller-provided ``graph_client`` in the run() context can be
        # reused. (Eagerly constructing a client here would make the
        # ``if not self._client`` check in ``_init_client`` dead code and
        # silently ignore the injected client.)
        self._client = None

    def _remove_nouns(self, text):
        """Split *text* into (non_noun_words, noun_words) via jieba POS tags.

        Non-word characters are stripped first (CJK ideographs U+4E00-U+9FA5
        are kept) so punctuation does not affect the comparison. Words whose
        POS flag starts with "n" are treated as nouns.
        """
        text = re.sub(r"[^\w\u4e00-\u9fa5]", "", text)
        words = list(pseg.cut(text))
        non_noun_words = [word for word, flag in words if not flag.startswith("n")]
        noun_words = [word for word, flag in words if flag.startswith("n")]
        return non_noun_words, noun_words

    def _compare_non_noun_parts(self, source, target):
        """Return (is_match, source_nouns, target_nouns).

        Two texts match when their non-noun word sequences are identical AND
        they contain the same number of nouns, so the nouns can be swapped
        positionally later.
        """
        source_filtered, source_nouns = self._remove_nouns(source)
        target_filtered, target_nouns = self._remove_nouns(target)

        if source_filtered == target_filtered:
            log.debug("The non-noun parts of the texts are identical.")
            return len(source_nouns) == len(target_nouns), source_nouns, target_nouns

        log.debug("The non-noun parts of the texts are different")
        log.debug("Source non noun words are: %s & Targets' are: %s", source_filtered, target_filtered)

        return False, [], []

    def _replace_gremlin_query(self, gremlin, source_filtered, target_filtered):
        """Substitute each example noun with the matching query noun.

        NOTE(review): plain substring replacement — if a source noun occurs
        in the gremlin outside the intended literal (or inside another noun)
        this can over-replace; acceptable for short example queries, but
        worth confirming against real example data.
        """
        for source_str, target_str in zip(source_filtered, target_filtered):
            gremlin = gremlin.replace(source_str, target_str)
        return gremlin

    def _generate_gremlin(self, query, examples):
        """Try each example in order; on the first structural match, rewrite
        and execute its gremlin.

        Returns (True, results) on success, or (False, None) when no example
        matches or execution fails — the caller then falls back to LLM-based
        generation.
        """
        for example in examples:
            example_query = example["query"]
            example_gremlin = example["gremlin"]

            is_match, source, target = self._compare_non_noun_parts(example_query, query)
            if is_match:
                gremlin = self._replace_gremlin_query(example_gremlin, source, target)
                log.debug("Gremlin generated by regex is : %s", gremlin)

                try:
                    result = self._client.gremlin().exec(gremlin=gremlin)["data"]
                    log.debug("Gremlin generated by regex execution res is : %s", result)
                    return True, [json.dumps(item, ensure_ascii=False) for item in result]
                except Exception as e:  # pylint: disable=broad-except
                    # Best-effort: a failing rewritten query means the regex
                    # path is unreliable for this input; defer to the LLM.
                    log.debug("Error %s happened executing gremlin query", e)
                    return False, None

        return False, None

    def _init_client(self, context):
        """Lazily initialize the graph client.

        Prefer a ``PyHugeClient`` supplied in *context* under "graph_client";
        otherwise build one from the context values, falling back to the
        global ``huge_settings`` (the configuration source the original eager
        constructor used).
        """
        if not self._client:
            graph_client = context.get("graph_client")
            if isinstance(graph_client, PyHugeClient):
                self._client = graph_client
            else:
                self._client = PyHugeClient(
                    context.get("ip") or huge_settings.graph_ip,
                    context.get("port") or huge_settings.graph_port,
                    context.get("graph") or huge_settings.graph_name,
                    context.get("user") or huge_settings.graph_user,
                    context.get("pwd") or huge_settings.graph_pwd,
                    context.get("graphspace") or huge_settings.graph_space,
                )
        assert self._client, "No valid graphdb client"

    def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Pipeline entry point.

        Reads "query" and "match_result" from *context*; writes
        "skip_llm_gremlin" (True when the regex path produced a result) and,
        on success, "graph_result".

        Raises:
            ValueError: if "query" is missing from the context.
        """
        self._init_client(context)

        query = context.get("query")
        if not query:
            raise ValueError("Query is required")

        examples = context.get("match_result")
        if not examples:
            # No cached examples to match against — let the LLM handle it.
            context["skip_llm_gremlin"] = False
            return context

        context["skip_llm_gremlin"], context["graph_result"] = self._generate_gremlin(query, examples)
        return context
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.operators.common_op.check_schema import CheckSchema
from hugegraph_llm.operators.common_op.print_result import PrintResult
from hugegraph_llm.operators.common_op.regex_based_gremlin_generate import RegexGremlinGenerate
from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
from hugegraph_llm.operators.index_op.build_gremlin_example_index import BuildGremlinExampleIndex
from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery
Expand Down Expand Up @@ -58,6 +59,10 @@ def example_index_query(self, num_examples):
self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples))
return self

def regex_gremlin_generate(self):
self.operators.append(RegexGremlinGenerate())
return self

def gremlin_generate_synthesize(
self, schema, gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,40 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
query_embedding = context.get("query_embedding")

self._gremlin_generator.clear()
self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt
).run(query=query, query_embedding=query_embedding)
if self._num_gremlin_generate_example > 0:
gremlin = gremlin_response["result"]
regex_res = (
self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
.regex_gremlin_generate()
.run(query=query, query_embedding=query_embedding)
)
if not regex_res["skip_llm_gremlin"]:
self._gremlin_generator.clear()
self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt
).run(query=query, query_embedding=query_embedding)
if self._num_gremlin_generate_example > 0:
gremlin = gremlin_response["result"]
else:
gremlin = gremlin_response["raw_result"]
log.info("Generated gremlin: %s", gremlin)
context["gremlin"] = gremlin
try:
result = self._client.gremlin().exec(gremlin=gremlin)["data"]
if result == [None]:
result = []
context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result]
if context["graph_result"]:
context["graph_result_flag"] = 1
context["graph_context_head"] = (
"The following are graph query result " f"from gremlin query `{gremlin}`.\n"
)
except Exception as e: # pylint: disable=broad-except
log.error(e)
context["graph_result"] = ""
else:
gremlin = gremlin_response["raw_result"]
log.info("Generated gremlin: %s", gremlin)
context["gremlin"] = gremlin
try:
result = self._client.gremlin().exec(gremlin=gremlin)["data"]
if result == [None]:
result = []
context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result]
if context["graph_result"]:
context["graph_result_flag"] = 1
context["graph_context_head"] = (
f"The following are graph query result " f"from gremlin query `{gremlin}`.\n"
)
except Exception as e: # pylint: disable=broad-except
log.error(e)
context["graph_result"] = ""
context["graph_result"] = regex_res["graph_result"]
if context["graph_result"]:
context["graph_result_flag"] = 1
return context

def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@

class AnswerSynthesize:
def __init__(
self,
llm: Optional[BaseLLM] = None,
prompt_template: Optional[str] = None,
question: Optional[str] = None,
context_body: Optional[str] = None,
context_head: Optional[str] = None,
context_tail: Optional[str] = None,
raw_answer: bool = False,
vector_only_answer: bool = True,
graph_only_answer: bool = False,
graph_vector_answer: bool = False,
self,
llm: Optional[BaseLLM] = None,
prompt_template: Optional[str] = None,
question: Optional[str] = None,
context_body: Optional[str] = None,
context_head: Optional[str] = None,
context_tail: Optional[str] = None,
raw_answer: bool = False,
vector_only_answer: bool = True,
graph_only_answer: bool = False,
graph_vector_answer: bool = False,
):
self._llm = llm
self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
Expand All @@ -70,9 +70,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""

if self._context_body is not None:
context_str = (f"{context_head_str}\n"
f"{self._context_body}\n"
f"{context_tail_str}".strip("\n"))
context_str = f"{context_head_str}\n" f"{self._context_body}\n" f"{context_tail_str}".strip("\n")

final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
response = self._llm.generate(prompt=final_prompt)
Expand All @@ -96,50 +94,51 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
graph_result_context = "No related graph data found for current query."
log.warning(graph_result_context)

context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
vector_result_context, graph_result_context))
context = asyncio.run(
self.async_generate(
context, context_head_str, context_tail_str, vector_result_context, graph_result_context
)
)
return context

async def async_generate(self, context: Dict[str, Any], context_head_str: str,
context_tail_str: str, vector_result_context: str,
graph_result_context: str):
async def async_generate(
self,
context: Dict[str, Any],
context_head_str: str,
context_tail_str: str,
vector_result_context: str,
graph_result_context: str,
):
# async_tasks stores the async tasks for different answer types
async_tasks = {}
if self._raw_answer:
final_prompt = self._question
async_tasks["raw_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._vector_only_answer:
context_str = (f"{context_head_str}\n"
f"{vector_result_context}\n"
f"{context_tail_str}".strip("\n"))
context_str = f"{context_head_str}\n" f"{vector_result_context}\n" f"{context_tail_str}".strip("\n")

final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
async_tasks["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._graph_only_answer:
context_str = (f"{context_head_str}\n"
f"{graph_result_context}\n"
f"{context_tail_str}".strip("\n"))
context_str = f"{context_head_str}\n" f"{graph_result_context}\n" f"{context_tail_str}".strip("\n")

final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
log.debug("Final Prompt is %s", final_prompt)
async_tasks["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._graph_vector_answer:
context_body_str = f"{vector_result_context}\n{graph_result_context}"
if context.get("graph_ratio", 0.5) < 0.5:
context_body_str = f"{graph_result_context}\n{vector_result_context}"
context_str = (f"{context_head_str}\n"
f"{context_body_str}\n"
f"{context_tail_str}".strip("\n"))
context_str = f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n")

final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
async_tasks["graph_vector_task"] = asyncio.create_task(
self._llm.agenerate(prompt=final_prompt)
)
async_tasks["graph_vector_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))

async_tasks_mapping = {
"raw_task": "raw_answer",
"vector_only_task": "vector_only_answer",
"graph_only_task": "graph_only_answer",
"graph_vector_task": "graph_vector_answer"
"graph_vector_task": "graph_vector_answer",
}

for task_key, context_key in async_tasks_mapping.items():
Expand All @@ -149,5 +148,5 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str,
log.debug("Query Answer: %s", response)

ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer])
context['call_count'] = context.get('call_count', 0) + ops
context["call_count"] = context.get("call_count", 0) + ops
return context
Loading