From 4397baeacec18f9976215ea6c34c695303c94c0f Mon Sep 17 00:00:00 2001
From: Haojin <1454yhj@gmail.com>
Date: Thu, 2 Jan 2025 19:56:17 +0800
Subject: [PATCH 1/2] feat(llm): Support generating gremlin based on regex
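
Match the user query against the retrieved gremlin examples by comparing
their non-noun tokens (jieba part-of-speech tagging). When only the nouns
differ and the noun counts match, substitute the query's nouns into the
example's gremlin template and execute it directly, so the LLM generation
step can be skipped.

A minimal usage sketch (illustrative only: it assumes a reachable HugeGraph
configured via huge_settings, and the stored example's gremlin with its
'act' edge label is hypothetical):

    op = RegexGremlinGenerate()
    context = op.run({
        "query": "谁主演了泰坦尼克号",
        "match_result": [{
            "query": "谁主演了流浪地球",
            "gremlin": "g.V().has('name', '流浪地球').in('act')",
        }],
    })
    # context["skip_llm_gremlin"] is True only when an example matched and
    # the rewritten gremlin executed successfully; the serialized rows are
    # then available in context["graph_result"].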
---
 .../common_op/regex_based_gremlin_generate.py | 112 ++++++++++++++++++
 .../operators/gremlin_generate_task.py        |   5 +
 .../operators/hugegraph_op/graph_rag_query.py |  58 +++++----
 .../operators/llm_op/answer_synthesize.py     |  65 +++++-----
 4 files changed, 185 insertions(+), 55 deletions(-)
 create mode 100644 hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py
new file mode 100644
index 00000000..296f5543
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, Optional, List, Set, Tuple
+import json
+import jieba
+import jieba.posseg as pseg
+import re
+
+from hugegraph_llm.config import huge_settings
+from hugegraph_llm.utils.log import log
+from pyhugegraph.client import PyHugeClient
+
+
+class RegexGremlinGenerate:
+    def __init__(self):
+        self._client = PyHugeClient(
+            huge_settings.graph_ip,
+            huge_settings.graph_port,
+            huge_settings.graph_name,
+            huge_settings.graph_user,
+            huge_settings.graph_pwd,
+            huge_settings.graph_space,
+        )
+
+    def _remove_nouns(self, text):
+        text = re.sub(r"[^\w\u4e00-\u9fa5]", "", text)
+        words = list(pseg.cut(text))
+        non_noun_words = [word for word, flag in words if not flag.startswith("n")]
+        noun_words = [word for word, flag in words if flag.startswith("n")]
+        return non_noun_words, noun_words
+
+    def _compare_non_noun_parts(self, source, target):
+        source_filtered, source_nouns = self._remove_nouns(source)
+        target_filtered, target_nouns = self._remove_nouns(target)
+
+        if source_filtered == target_filtered:
+            log.debug("The non-noun parts of the texts are identical.")
+            return len(source_nouns) == len(target_nouns), source_nouns, target_nouns
+
+        log.debug("The non-noun parts of the texts are different")
+        log.debug(f"Source non noun words are: {source_filtered} & Targets' are: {target_filtered}")
+
+        return False, [], []
+
+    def _replace_gremlin_query(self, gremlin, source_filtered, target_filtered):
+        for source_str, target_str in zip(source_filtered, target_filtered):
+            gremlin = gremlin.replace(source_str, target_str)
+        return gremlin
+
+    def _generate_gremlin(self, query, examples):
+        for example in examples:
+            example_query = example["query"]
+            example_gremlin = example["gremlin"]
+
+            is_match, source, target = self._compare_non_noun_parts(example_query, query)
+            if is_match:
+                gremlin = self._replace_gremlin_query(example_gremlin, source, target)
+                log.debug(f"Gremlin generated by regex is : {gremlin}")
+
+                try:
+                    result = self._client.gremlin().exec(gremlin=gremlin)["data"]
+                    log.debug(f"Gremlin generated by regex execution res is : {result}")
+                    return True, [json.dumps(item, ensure_ascii=False) for item in result]
+                except Exception as e:
+                    return False, None
+
+        return False, None
+
+    def _init_client(self, context):
+        if not self._client:
+            graph_client = context.get("graph_client")
+            if isinstance(graph_client, PyHugeClient):
+                self._client = graph_client
+            else:
+                self._client = PyHugeClient(
+                    context.get("ip", "localhost"),
+                    context.get("port", "8080"),
+                    context.get("graph", "hugegraph"),
+                    context.get("user", "admin"),
+                    context.get("pwd", "admin"),
+                    context.get("graphspace"),
+                )
+        assert self._client, "No valid graphdb client"
+
+    def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        self._init_client(context)
+
+        query = context.get("query")
+        if not query:
+            raise ValueError("Query is required")
+
+        examples = context.get("match_result")
+        if not examples:
+            context["skip_llm_gremlin"] = False
+            return context
+
+        context["skip_llm_gremlin"], context["graph_result"] = self._generate_gremlin(query, examples)
+        return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
index 95ce59f0..111916db 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py
@@ -20,6 +20,7 @@
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.operators.common_op.check_schema import CheckSchema
 from hugegraph_llm.operators.common_op.print_result import PrintResult
+from hugegraph_llm.operators.common_op.regex_based_gremlin_generate import RegexGremlinGenerate
 from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
 from hugegraph_llm.operators.index_op.build_gremlin_example_index import BuildGremlinExampleIndex
 from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery
@@ -58,6 +59,10 @@ def example_index_query(self, num_examples):
         self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples))
         return self
 
+    def regex_gremlin_generate(self):
+        self.operators.append(RegexGremlinGenerate())
+        return self
+
     def gremlin_generate_synthesize(
         self, schema, gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None
     ):
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index e213c372..9c24b2a4 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -132,29 +132,43 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query_embedding = context.get("query_embedding")
 
         self._gremlin_generator.clear()
-        self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
-        gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
-            context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt
-        ).run(query=query, query_embedding=query_embedding)
-        if self._num_gremlin_generate_example > 0:
gremlin_response["result"] + regex_res = ( + self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example) + .regex_gremlin_generate() + .run(query=query, query_embedding=query_embedding) + ) + + log.debug("Skip llm gremlin is %s", regex_res["skip_llm_gremlin"]) + if not regex_res["skip_llm_gremlin"]: + self._gremlin_generator.clear() + self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example) + gremlin_response = self._gremlin_generator.gremlin_generate_synthesize( + context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt + ).run(query=query, query_embedding=query_embedding) + if self._num_gremlin_generate_example > 0: + gremlin = gremlin_response["result"] + else: + gremlin = gremlin_response["raw_result"] + log.info("Generated gremlin: %s", gremlin) + context["gremlin"] = gremlin + try: + result = self._client.gremlin().exec(gremlin=gremlin)["data"] + if result == [None]: + result = [] + context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result] + if context["graph_result"]: + context["graph_result_flag"] = 1 + context["graph_context_head"] = ( + f"The following are graph query result " f"from gremlin query `{gremlin}`.\n" + ) + except Exception as e: # pylint: disable=broad-except + log.error(e) + context["graph_result"] = "" else: - gremlin = gremlin_response["raw_result"] - log.info("Generated gremlin: %s", gremlin) - context["gremlin"] = gremlin - try: - result = self._client.gremlin().exec(gremlin=gremlin)["data"] - if result == [None]: - result = [] - context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result] - if context["graph_result"]: - context["graph_result_flag"] = 1 - context["graph_context_head"] = ( - f"The following are graph query result " f"from gremlin query `{gremlin}`.\n" - ) - except Exception as e: # pylint: disable=broad-except - log.error(e) - context["graph_result"] = "" + context["graph_result"] = regex_res["graph_result"] + if context["graph_result"]: + context["graph_result_flag"] = 1 + context["graph_context_head"] = f"The following are graph query result " f"from regex gremlin.\n" return context def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 666ecf99..ee6e07a8 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -35,17 +35,17 @@ class AnswerSynthesize: def __init__( - self, - llm: Optional[BaseLLM] = None, - prompt_template: Optional[str] = None, - question: Optional[str] = None, - context_body: Optional[str] = None, - context_head: Optional[str] = None, - context_tail: Optional[str] = None, - raw_answer: bool = False, - vector_only_answer: bool = True, - graph_only_answer: bool = False, - graph_vector_answer: bool = False, + self, + llm: Optional[BaseLLM] = None, + prompt_template: Optional[str] = None, + question: Optional[str] = None, + context_body: Optional[str] = None, + context_head: Optional[str] = None, + context_tail: Optional[str] = None, + raw_answer: bool = False, + vector_only_answer: bool = True, + graph_only_answer: bool = False, + graph_vector_answer: bool = False, ): self._llm = llm self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE @@ -70,9 +70,7 @@ def run(self, context: Dict[str, Any]) -> 
@@ -70,9 +70,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""
 
         if self._context_body is not None:
-            context_str = (f"{context_head_str}\n"
-                           f"{self._context_body}\n"
-                           f"{context_tail_str}".strip("\n"))
+            context_str = f"{context_head_str}\n" f"{self._context_body}\n" f"{context_tail_str}".strip("\n")
 
             final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
             response = self._llm.generate(prompt=final_prompt)
@@ -96,50 +94,51 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
                 graph_result_context = "No related graph data found for current query."
                 log.warning(graph_result_context)
 
-        context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
-                                                  vector_result_context, graph_result_context))
+        context = asyncio.run(
+            self.async_generate(
+                context, context_head_str, context_tail_str, vector_result_context, graph_result_context
+            )
+        )
         return context
 
-    async def async_generate(self, context: Dict[str, Any], context_head_str: str,
-                             context_tail_str: str, vector_result_context: str,
-                             graph_result_context: str):
+    async def async_generate(
+        self,
+        context: Dict[str, Any],
+        context_head_str: str,
+        context_tail_str: str,
+        vector_result_context: str,
+        graph_result_context: str,
+    ):
         # async_tasks stores the async tasks for different answer types
         async_tasks = {}
         if self._raw_answer:
             final_prompt = self._question
             async_tasks["raw_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
         if self._vector_only_answer:
-            context_str = (f"{context_head_str}\n"
-                           f"{vector_result_context}\n"
-                           f"{context_tail_str}".strip("\n"))
+            context_str = f"{context_head_str}\n" f"{vector_result_context}\n" f"{context_tail_str}".strip("\n")
 
             final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
             async_tasks["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
         if self._graph_only_answer:
-            context_str = (f"{context_head_str}\n"
-                           f"{graph_result_context}\n"
-                           f"{context_tail_str}".strip("\n"))
+            context_str = f"{context_head_str}\n" f"{graph_result_context}\n" f"{context_tail_str}".strip("\n")
 
             final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
+            log.debug("Final Prompt is %s", final_prompt)
             async_tasks["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
         if self._graph_vector_answer:
             context_body_str = f"{vector_result_context}\n{graph_result_context}"
             if context.get("graph_ratio", 0.5) < 0.5:
                 context_body_str = f"{graph_result_context}\n{vector_result_context}"
-            context_str = (f"{context_head_str}\n"
-                           f"{context_body_str}\n"
-                           f"{context_tail_str}".strip("\n"))
+            context_str = f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n")
 
             final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
-            async_tasks["graph_vector_task"] = asyncio.create_task(
-                self._llm.agenerate(prompt=final_prompt)
-            )
+            async_tasks["graph_vector_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
 
         async_tasks_mapping = {
             "raw_task": "raw_answer",
             "vector_only_task": "vector_only_answer",
             "graph_only_task": "graph_only_answer",
-            "graph_vector_task": "graph_vector_answer"
+            "graph_vector_task": "graph_vector_answer",
         }
 
         for task_key, context_key in async_tasks_mapping.items():
@@ -149,5 +148,5 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str,
Answer: %s", response) ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer]) - context['call_count'] = context.get('call_count', 0) + ops + context["call_count"] = context.get("call_count", 0) + ops return context From 333d5804f676f29c0153145ca3e7b7acb1d4a62d Mon Sep 17 00:00:00 2001 From: Haojin <1454yhj@gmail.com> Date: Thu, 2 Jan 2025 20:15:51 +0800 Subject: [PATCH 2/2] fix(llm): fix pylint --- .../common_op/regex_based_gremlin_generate.py | 14 +++++++------- .../operators/hugegraph_op/graph_rag_query.py | 5 +---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py index 296f5543..c5eae64a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/regex_based_gremlin_generate.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional, List, Set, Tuple +from typing import Any, Dict import json -import jieba -import jieba.posseg as pseg import re +import jieba.posseg as pseg from hugegraph_llm.config import huge_settings from hugegraph_llm.utils.log import log @@ -52,7 +51,7 @@ def _compare_non_noun_parts(self, source, target): return len(source_nouns) == len(target_nouns), source_nouns, target_nouns log.debug("The non-noun parts of the texts are different") - log.debug(f"Source non noun words are: {source_filtered} & Targets' are: {target_filtered}") + log.debug("Source non noun words are: %s & Targets' are: %s", source_filtered, target_filtered) return False, [], [] @@ -69,13 +68,14 @@ def _generate_gremlin(self, query, examples): is_match, source, target = self._compare_non_noun_parts(example_query, query) if is_match: gremlin = self._replace_gremlin_query(example_gremlin, source, target) - log.debug(f"Gremlin generated by regex is : {gremlin}") + log.debug("Gremlin generated by regex is : %s", gremlin) try: result = self._client.gremlin().exec(gremlin=gremlin)["data"] - log.debug(f"Gremlin generated by regex execution res is : {result}") + log.debug("Gremlin generated by regex execution res is : %s", result) return True, [json.dumps(item, ensure_ascii=False) for item in result] - except Exception as e: + except Exception as e: # pylint: disable=broad-except + log.debug("Error %s happened executing gremlin query", e) return False, None return False, None diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 9c24b2a4..e8a05dbe 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -137,8 +137,6 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: .regex_gremlin_generate() .run(query=query, query_embedding=query_embedding) ) - - log.debug("Skip llm gremlin is %s", regex_res["skip_llm_gremlin"]) if not regex_res["skip_llm_gremlin"]: self._gremlin_generator.clear() self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example) @@ -159,7 +157,7 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: if 
context["graph_result"]: context["graph_result_flag"] = 1 context["graph_context_head"] = ( - f"The following are graph query result " f"from gremlin query `{gremlin}`.\n" + "The following are graph query result " f"from gremlin query `{gremlin}`.\n" ) except Exception as e: # pylint: disable=broad-except log.error(e) @@ -168,7 +166,6 @@ def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: context["graph_result"] = regex_res["graph_result"] if context["graph_result"]: context["graph_result_flag"] = 1 - context["graph_context_head"] = f"The following are graph query result " f"from regex gremlin.\n" return context def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: