refactor(llm): merge & update keyword extraction logic/prompt (#120)
* update English prompt

---------

Co-authored-by: imbajin <[email protected]>
HJ-Young and imbajin authored Nov 27, 2024
1 parent 2d8e6c8 commit ef56263
Showing 8 changed files with 104 additions and 70 deletions.
6 changes: 6 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -124,6 +124,9 @@ def save_to_yaml(self):
"\n".join([f" {line}" for line in self.custom_rerank_info.splitlines()])
)
indented_default_answer_template = "\n".join([f" {line}" for line in self.answer_prompt.splitlines()])
indented_keywords_extract_template = (
"\n".join([f" {line}" for line in self.keywords_extract_prompt.splitlines()])
)

# This can be extended to add storage fields according to what data needs to be stored
yaml_content = f"""graph_schema: |
@@ -141,6 +144,9 @@
answer_prompt: |
{indented_default_answer_template}
keywords_extract_prompt: |
{indented_keywords_extract_template}
"""
with open(yaml_file_path, "w", encoding="utf-8") as file:
file.write(yaml_content)
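
The two-space indent built above is what keeps the multi-line prompt inside the YAML literal block. A minimal, self-contained sketch of the pattern (variable names are illustrative, not the module's API):

# Sketch: indent every line of a multi-line prompt by two spaces so that
# YAML parses it as the body of a `|` literal block scalar.
keywords_extract_prompt = "Instructions:\nExtract keywords from the text."

indented = "\n".join(f"  {line}" for line in keywords_extract_prompt.splitlines())
yaml_content = f"""keywords_extract_prompt: |
{indented}
"""
print(yaml_content)
# keywords_extract_prompt: |
#   Instructions:
#   Extract keywords from the text.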
48 changes: 48 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/config/config_data.py
@@ -218,3 +218,51 @@ class PromptData:
]
}
"""

# Extracted from llm_op/keyword_extract.py
keywords_extract_prompt = """指令:
请对以下文本执行以下任务:
1. 从文本中提取关键词:
- 最少 0 个,最多 {max_keywords} 个。
- 关键词应为具有完整语义的词语或短语,确保信息完整。
2. 识别需改写的关键词:
- 从提取的关键词中,识别那些在原语境中具有歧义或存在信息缺失的关键词。
3. 生成同义词:
- 对这些需改写的关键词,生成其在给定语境下的同义词或含义相近的词语。
- 使用生成的同义词替换原文中的相应关键词。
- 如果某个关键词没有合适的同义词,则保留该关键词不变。
要求:
- 关键词应为有意义且具体的实体,避免使用无意义或过于宽泛的词语,或单字符的词(例如:“物品”、“动作”、“效果”、“作用”、“的”、“他”)。
- 优先提取主语、动词和宾语,避免提取虚词或助词。
- 保持语义完整性: 抽取的关键词应尽量保持关键词在原语境中语义和信息的完整性(例如:“苹果电脑”应作为一个整体被抽取,而不是被分为“苹果”和“电脑”)。
- 避免泛化: 不要扩展为不相关的泛化类别。
注意:
- 仅考虑语境相关的同义词: 只需考虑给定语境下的关键词的语义近义词和具有类似含义的其他词语。
- 调整关键词长度: 如果关键词相对宽泛,可以根据语境适当增加单个关键词的长度(例如:“违法行为”可以作为一个单独的关键词被抽取,或抽取为“违法”,但不应拆分为“违法”和“行为”)。
输出格式:
- 仅输出一行内容, 以 KEYWORDS: 为前缀,后跟所有关键词或对应的同义词,之间用逗号分隔。抽取的关键词中不允许出现空格或空字符
- 格式示例:
KEYWORDS:关键词1,关键词2,...,关键词n
文本:
{question}
"""

# keywords_extract_prompt_EN = """
# Instruction:
# Please perform the following tasks on the text below:
# 1. Extract Keywords and Generate Synonyms from text:
# - At least 0, at most {max_keywords} keywords.
# - For each keyword, generate its synonyms or possible variant forms.
# Requirements:
# - Keywords should be meaningful and specific entities; avoid using meaningless or overly broad terms (e.g., “object,” “the,” “he”).
# - Prioritize extracting subjects, verbs, and objects; avoid extracting function words or auxiliary words.
# - Do not expand into unrelated generalized categories.
# Note:
# - Only consider semantic synonyms and other words with similar meanings in the given context.
# Output Format:
# - Output only one line, prefixed with KEYWORDS:, followed by all keywords and synonyms, separated by commas. No spaces or empty characters are allowed in the extracted keywords.
# - Format example:
# KEYWORDS: keyword1, keyword2, ..., keywordn, synonym1, synonym2, ..., synonymn
# Text:
# {question}
# """
7 changes: 4 additions & 3 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -92,7 +92,7 @@ def init_rag_ui() -> gr.Interface:
with gr.Tab(label="1. Build RAG Index 💡"):
textbox_input_schema, textbox_info_extract_template = create_vector_graph_block()
with gr.Tab(label="2. (Graph)RAG & User Functions 📖"):
textbox_inp, textbox_answer_prompt_input = create_rag_block()
textbox_inp, textbox_answer_prompt_input, textbox_keywords_extract_prompt_input = create_rag_block()
with gr.Tab(label="3. Graph Tools 🚧"):
create_other_block()
with gr.Tab(label="4. Admin Tools ⚙️"):
@@ -105,7 +105,7 @@ def refresh_ui_config_prompt() -> tuple:
return (
settings.graph_ip, settings.graph_port, settings.graph_name, settings.graph_user,
settings.graph_pwd, settings.graph_space, prompt.graph_schema, prompt.extract_graph_prompt,
prompt.default_question, prompt.answer_prompt
prompt.default_question, prompt.answer_prompt, prompt.keywords_extract_prompt
)

hugegraph_llm_ui.load(fn=refresh_ui_config_prompt, outputs=[
@@ -120,7 +120,8 @@ def refresh_ui_config_prompt() -> tuple:
textbox_info_extract_template,

textbox_inp,
textbox_answer_prompt_input
textbox_answer_prompt_input,
textbox_keywords_extract_prompt_input
])

return hugegraph_llm_ui
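
The change above threads one extra textbox out of create_rag_block() and registers it as an additional output of the load refresh. A minimal sketch of that Gradio wiring, with simplified components and a stand-in refresh function:

# Minimal sketch of the load/outputs pattern (not the real app): the
# refresh function must return one value per output component, in order.
import gradio as gr

def refresh_prompts() -> tuple:
    return ("default question", "answer prompt", "keywords extraction prompt")

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Question")
    answer_prompt = gr.Textbox(label="Query Prompt")
    keywords_prompt = gr.Textbox(label="Keywords Extraction Prompt")
    # on page load, repopulate each textbox from the returned tuple
    demo.load(fn=refresh_prompts, outputs=[inp, answer_prompt, keywords_prompt])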
@@ -251,7 +251,7 @@ def chat_llm_settings(llm_type):
llm_config_button = gr.Button("Apply configuration")
llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input)

with gr.Tab(label='extract'):
with gr.Tab(label='mini_tasks'):
extract_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
value=getattr(settings, f"extract_llm_type"), label=f"type")
apply_llm_config_with_extract_op = partial(apply_llm_config, "extract")
19 changes: 14 additions & 5 deletions hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -40,6 +40,7 @@ def rag_answer(
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
) -> Tuple:
"""
Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
@@ -49,11 +50,12 @@
4. Synthesize the final answer.
5. Run the pipeline and return the results.
"""
should_update_prompt = prompt.default_question != text or prompt.answer_prompt != answer_prompt
should_update_prompt = prompt.default_question != text or prompt.answer_prompt != answer_prompt or prompt.keywords_extract_prompt != keywords_extract_prompt
if should_update_prompt or prompt.custom_rerank_info != custom_related_information:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
prompt.keywords_extract_prompt = keywords_extract_prompt
prompt.update_yaml_file()

vector_search = vector_only_answer or graph_vector_answer
@@ -66,7 +68,7 @@
if vector_search:
rag.query_vector_index()
if graph_search:
rag.extract_keywords().keywords_to_vid().query_graphdb()
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().query_graphdb()
# TODO: add more user-defined search strategies
rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information)
rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
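
The chain rag.extract_keywords(...).keywords_to_vid().query_graphdb() works because each operator method appends to an internal list and returns self. A simplified stand-in (not the repo's pipeline class) showing the shape of that chaining:

# Simplified fluent-pipeline stand-in: each step records an operator and
# returns self so calls can be chained; a run() method would execute them in order.
from typing import List, Optional, Tuple

class MiniPipeline:
    def __init__(self) -> None:
        self._operators: List[Tuple[str, Optional[str]]] = []

    def extract_keywords(self, extract_template: Optional[str] = None) -> "MiniPipeline":
        self._operators.append(("extract_keywords", extract_template))
        return self

    def keywords_to_vid(self) -> "MiniPipeline":
        self._operators.append(("keywords_to_vid", None))
        return self

    def query_graphdb(self) -> "MiniPipeline":
        self._operators.append(("query_graphdb", None))
        return self

pipe = MiniPipeline().extract_keywords("KEYWORDS: ...").keywords_to_vid().query_graphdb()
print([name for name, _ in pipe._operators])
# ['extract_keywords', 'keywords_to_vid', 'query_graphdb']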
@@ -101,7 +103,10 @@ def create_rag_block():
graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True)

answer_prompt_input = gr.Textbox(
value=prompt.answer_prompt, label="Custom Prompt", show_copy_button=True, lines=7
value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7
)
keywords_extract_prompt_input = gr.Textbox(
value=prompt.keywords_extract_prompt, label="Keywords Extraction Prompt", show_copy_button=True, lines=7
)
with gr.Column(scale=1):
with gr.Row():
@@ -134,7 +139,7 @@ def toggle_slider(enable):
)
custom_related_information = gr.Text(
prompt.custom_rerank_info,
label="Custom related information(Optional)",
label="Query related information(Optional)",
)
btn = gr.Button("Answer Question", variant="primary")

@@ -151,6 +156,7 @@ def toggle_slider(enable):
near_neighbor_first,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
)
@@ -209,6 +215,7 @@ def several_rag_answer(
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
keywords_extract_prompt: str,
progress=gr.Progress(track_tqdm=True),
answer_max_line_count: int = 1,
):
@@ -227,6 +234,7 @@
near_neighbor_first,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
)
df.at[index, "Basic LLM Answer"] = basic_llm_answer
df.at[index, "Vector-only Answer"] = vector_only_answer
@@ -259,10 +267,11 @@ def several_rag_answer(
near_neighbor_first,
custom_related_information,
answer_prompt_input,
keywords_extract_prompt_input,
answer_max_line_count,
],
outputs=[qa_dataframe, gr.File(label="Download Answered File", min_width=40)],
)
questions_file.change(read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count])
answer_max_line_count.change(change_showing_excel, answer_max_line_count, qa_dataframe)
return inp, answer_prompt_input
return inp, answer_prompt_input, keywords_extract_prompt_input
hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
@@ -33,7 +33,7 @@ class NLTKHelper:
"chinese": None,
}

def stopwords(self, lang: str = "english") -> List[str]:
def stopwords(self, lang: str = "chinese") -> List[str]:
"""Get stopwords."""
nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
if self._stopwords.get(lang) is None:
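
Only the default language changes here, but the surrounding method is a lazy per-language cache: stopwords are loaded on first request and reused afterwards. A self-contained sketch of that pattern (the hardcoded word lists are placeholders for the real NLTK data):

# Sketch of NLTKHelper's lazy caching: a dict maps language -> stopword
# list, populated on first access and reused on every later call.
from typing import Dict, List, Optional

class StopwordCacheSketch:
    _stopwords: Dict[str, Optional[List[str]]] = {"english": None, "chinese": None}

    def stopwords(self, lang: str = "chinese") -> List[str]:
        if self._stopwords.get(lang) is None:
            # placeholder load; the real helper reads NLTK corpus files
            self._stopwords[lang] = ["的", "了"] if lang == "chinese" else ["the", "a"]
        return self._stopwords[lang]

helper = StopwordCacheSketch()
assert helper.stopwords() == ["的", "了"]           # default is now Chinese
assert helper.stopwords("english") == ["the", "a"]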
3 changes: 0 additions & 3 deletions hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -69,7 +69,6 @@ def extract_keywords(
max_keywords: int = 5,
language: str = "english",
extract_template: Optional[str] = None,
expand_template: Optional[str] = None,
):
"""
Add a keyword extraction operator to the pipeline.
@@ -78,7 +77,6 @@
:param max_keywords: Maximum number of keywords to extract.
:param language: Language of the text.
:param extract_template: Template for keyword extraction.
:param expand_template: Template for keyword expansion.
:return: Self-instance for chaining.
"""
self._operators.append(
@@ -87,7 +85,6 @@
max_keywords=max_keywords,
language=language,
extract_template=extract_template,
expand_template=expand_template,
)
)
return self
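
With expand_template gone, callers configure the step through extract_template alone. A hedged usage sketch, assuming the class defined in this file is RAGPipeline (as the demo code imports it) and using a made-up template:

# Hypothetical caller after this change: synonym expansion now rides along
# in the single extraction prompt, so only one template is passed.
from hugegraph_llm.operators.graph_rag_task import RAGPipeline

my_template = "Extract at most {max_keywords} keywords.\nText:\n{question}"
rag = RAGPipeline()
rag.extract_keywords(max_keywords=5, language="english", extract_template=my_template)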
87 changes: 30 additions & 57 deletions hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -17,46 +17,32 @@


import re
import time
from typing import Set, Dict, Any, Optional

from hugegraph_llm.config import prompt
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
from hugegraph_llm.utils.log import log

KEYWORDS_EXTRACT_TPL = """Extract {max_keywords} keywords from the text:
{question}
1. Keywords can't contain meaningless/broad words(e.g action/relation/thing), must represent certain entities,
2. Better to extract subject/verb/object and don't extract particles, don't extend to synonyms/general categories.
Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'
"""

KEYWORDS_EXPAND_TPL = """Generate synonyms or possible form of keywords up to {max_keywords} in total,
considering possible cases of capitalization, pluralization, common expressions, etc.
Provide all synonyms of keywords in comma-separated format: 'SYNONYMS: <keywords>'
Note, result should be in one-line with only one 'SYNONYMS: ' prefix
----
KEYWORDS: {question}
----"""
KEYWORDS_EXTRACT_TPL = prompt.keywords_extract_prompt


class KeywordExtract:
def __init__(
self,
text: Optional[str] = None,
llm: Optional[BaseLLM] = None,
max_keywords: int = 5,
extract_template: Optional[str] = None,
expand_template: Optional[str] = None,
language: str = "english",
self,
text: Optional[str] = None,
llm: Optional[BaseLLM] = None,
max_keywords: int = 5,
extract_template: Optional[str] = None,
language: str = "english",
):
self._llm = llm
self._query = text
self._language = language.lower()
self._max_keywords = max_keywords
self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL
self._expand_template = expand_template or KEYWORDS_EXPAND_TPL

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._query is None:
@@ -69,61 +55,48 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
self._llm = LLMs().get_extract_llm()
assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."

if isinstance(context.get("language"), str):
self._language = context["language"].lower()
else:
context["language"] = self._language

if isinstance(context.get("max_keywords"), int):
self._max_keywords = context["max_keywords"]
self._language = context.get("language", self._language).lower()
self._max_keywords = context.get("max_keywords", self._max_keywords)

prompt = self._extract_template.format(question=self._query, max_keywords=self._max_keywords)
prompt = f"{self._extract_template.format(question=self._query, max_keywords=self._max_keywords)}"
start_time = time.perf_counter()
response = self._llm.generate(prompt=prompt)
end_time = time.perf_counter()
log.debug("Keyword extraction time: %.2f seconds", end_time - start_time)

keywords = self._extract_keywords_from_response(
response=response, lowercase=False, start_token="KEYWORDS:"
)
keywords.union(self._expand_synonyms(keywords=keywords))
keywords = {k.replace("'", "") for k in keywords}
context["keywords"] = list(keywords)
log.info("User Query: %s\nKeywords: %s", self._query, context["keywords"])

# extracting keywords & expanding synonyms increase the call count by 2
context["call_count"] = context.get("call_count", 0) + 2
# keyword extraction & synonym expansion now share a single LLM call, so the call count increases by 1
context["call_count"] = context.get("call_count", 0) + 1
return context

def _expand_synonyms(self, keywords: Set[str]) -> Set[str]:
prompt = self._expand_template.format(question=str(keywords), max_keywords=self._max_keywords)
response = self._llm.generate(prompt=prompt)
keywords = self._extract_keywords_from_response(
response=response, lowercase=False, start_token="SYNONYMS:"
)
return keywords

def _extract_keywords_from_response(
self,
response: str,
lowercase: bool = True,
start_token: str = "",
self,
response: str,
lowercase: bool = True,
start_token: str = "",
) -> Set[str]:
keywords = []
# use re.escape(start_token) if start_token contains special chars like */&/^ etc.
matches = re.findall(rf'{start_token}[^\n]+\n?', response)

for match in matches:
match = match[len(start_token):]
for k in re.split(r"[,,]+", match):
k = k.strip()
if len(k) > 1:
if lowercase:
keywords.append(k.lower())
else:
keywords.append(k)
match = match[len(start_token):].strip()
keywords.extend(
k.lower() if lowercase else k
for k in re.split(r"[,,]+", match)
if len(k.strip()) > 1
)

# if the keyword consists of multiple words, split into sub-words (removing stopwords)
results = set()
results = set(keywords)
for token in keywords:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)})
results.update(w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language))
return results
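
The merged parsing path above does two things: pull the single KEYWORDS: line out of the LLM response and split it on ASCII or full-width commas, then break multi-word keywords into stopword-filtered sub-words. A self-contained sketch of that logic (STOPWORDS stands in for NLTKHelper().stopwords(...)):

# Sketch of the response-parsing step: find 'KEYWORDS:' lines, split on
# "," or "," (full-width), keep tokens longer than one character, then add
# stopword-filtered sub-words of multi-word keywords.
import re
from typing import Set

STOPWORDS = {"of", "the"}  # placeholder for NLTKHelper().stopwords(...)

def parse_keywords(response: str, start_token: str = "KEYWORDS:") -> Set[str]:
    keywords = []
    for match in re.findall(rf"{start_token}[^\n]+\n?", response):
        parts = (p.strip() for p in re.split(r"[,,]+", match[len(start_token):]))
        keywords.extend(p for p in parts if len(p) > 1)

    results = set(keywords)
    for token in keywords:
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:  # multi-word keyword -> also add useful sub-words
            results.update(w for w in sub_tokens if w not in STOPWORDS)
    return results

print(parse_keywords("KEYWORDS:graph database,state of the art"))
# e.g. {'graph database', 'state of the art', 'graph', 'database', 'state', 'art'}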
