Skip to content

Commit

Permalink
Add citation
Browse files Browse the repository at this point in the history
  • Loading branch information
moria97 committed Dec 9, 2024
1 parent c7d6801 commit 0fc8751
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 59 deletions.
1 change: 1 addition & 0 deletions src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class RagQuery(BaseModel):
session_id: str | None = None
vector_db: VectorDbConfig | None = None
stream: bool | None = False
citation: bool | None = False
with_intent: bool | None = False
index_name: str | None = None

Expand Down
12 changes: 11 additions & 1 deletion src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def query(
text: str,
with_history: bool = False,
stream: bool = False,
citation: bool = False,
with_intent: bool = False,
index_name: str = None,
):
Expand All @@ -221,9 +222,11 @@ def query(
question=text,
session_id=session_id,
stream=stream,
citation=citation,
with_intent=with_intent,
index_name=index_name,
)
print(q)
r = requests.post(self.query_url, json=q, stream=True)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)
Expand All @@ -248,10 +251,17 @@ def query_search(
self,
text: str,
with_history: bool = False,
citation: bool = False,
stream: bool = False,
):
session_id = self.session_id if with_history else None
q = dict(question=text, session_id=session_id, stream=stream, with_intent=False)
q = dict(
question=text,
session_id=session_id,
stream=stream,
with_intent=False,
citation=citation,
)
r = requests.post(self.search_url, json=q, stream=True)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)
Expand Down
40 changes: 36 additions & 4 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def respond(input_elements: List[Any]):
chatbot = update_dict["chatbot"]
is_streaming = update_dict["is_streaming"]
index_name = update_dict["chat_index"]
citation = update_dict["citation"]

if chatbot is not None:
chatbot.append((msg, ""))
Expand All @@ -51,13 +52,17 @@ def respond(input_elements: List[Any]):

elif query_type == "RAG (Search Web)":
response_gen = rag_client.query_search(
msg, with_history=update_dict["include_history"], stream=is_streaming
msg,
with_history=update_dict["include_history"],
stream=is_streaming,
citation=citation,
)
else:
response_gen = rag_client.query(
msg,
with_history=update_dict["include_history"],
stream=is_streaming,
citation=citation,
index_name=index_name,
)

Expand Down Expand Up @@ -94,6 +99,12 @@ def create_chat_tab() -> Dict[str, Any]:
elem_id="is_streaming",
value=True,
)
citation = gr.Checkbox(
label="Citation",
info="Need Citation",
elem_id="citation",
value=True,
)
need_image = gr.Checkbox(
label="Display Image",
info="Inference with multi-modal LLM.",
Expand Down Expand Up @@ -281,19 +292,35 @@ def change_retrieval_mode(retrieval_mode):
search_args = {search_api_key, search_count, search_lang}

with gr.Column(visible=True) as lc_col:
with gr.Tab("LLM Prompt"):
with gr.Tab("Prompt"):
text_qa_template = gr.Textbox(
label="Prompt Template",
value="",
elem_id="text_qa_template",
lines=10,
interactive=True,
)
with gr.Tab("MultiModal LLM Prompt"):
citation_text_qa_template = gr.Textbox(
label="Citation Prompt Template",
value="",
elem_id="citation_text_qa_template",
lines=10,
interactive=True,
)
with gr.Tab("MultiModal Prompt"):
multimodal_qa_template = gr.Textbox(
label="Multi-modal LLM Prompt Template",
label="Multi-modal Prompt Template",
value="",
elem_id="multimodal_qa_template",
lines=12,
interactive=True,
)
citation_multimodal_qa_template = gr.Textbox(
label="Citation Multi-modal Prompt Template",
value="",
elem_id="citation_multimodal_qa_template",
lines=12,
interactive=True,
)

cur_tokens = gr.Textbox(
Expand Down Expand Up @@ -367,10 +394,13 @@ def change_query_radio(query_type):
{
text_qa_template,
multimodal_qa_template,
citation_text_qa_template,
citation_multimodal_qa_template,
question,
query_type,
chatbot,
is_streaming,
citation,
need_image,
include_history,
chat_index,
Expand Down Expand Up @@ -419,6 +449,8 @@ def change_query_radio(query_type):
similarity_threshold.elem_id: similarity_threshold,
reranker_similarity_threshold.elem_id: reranker_similarity_threshold,
multimodal_qa_template.elem_id: multimodal_qa_template,
citation_multimodal_qa_template.elem_id: citation_multimodal_qa_template,
citation_text_qa_template.elem_id: citation_text_qa_template,
text_qa_template.elem_id: text_qa_template,
search_lang.elem_id: search_lang,
search_api_key.elem_id: search_api_key,
Expand Down
2 changes: 0 additions & 2 deletions src/pai_rag/app/web/ui_constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from pai_rag.integrations.synthesizer.pai_synthesizer import (
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
)
from pai_rag.utils.prompt_template import DEFAULT_TEXT_QA_PROMPT_TMPL

DEFAULT_TEXT_QA_PROMPT_TMPL = DEFAULT_TEXT_QA_PROMPT_TMPL
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL = DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL

DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
Expand Down
26 changes: 22 additions & 4 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from pai_rag.app.web.ui_constants import (
LLM_MODEL_KEY_DICT,
MLLM_MODEL_KEY_DICT,
DEFAULT_TEXT_QA_PROMPT_TMPL,
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
)
import pandas as pd
import os
Expand Down Expand Up @@ -115,8 +113,10 @@ class ViewModel(BaseModel):

synthesizer_type: str = None

text_qa_template: str = DEFAULT_TEXT_QA_PROMPT_TMPL
multimodal_qa_template: str = DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL
text_qa_template: str = None
multimodal_qa_template: str = None
citation_text_qa_template: str = None
citation_multimodal_qa_template: str = None

# agent
agent_api_definition: str = None # API tool definition
Expand Down Expand Up @@ -199,6 +199,12 @@ def from_app_config(config: RagConfig):

view_model.text_qa_template = config.synthesizer.text_qa_template
view_model.multimodal_qa_template = config.synthesizer.multimodal_qa_template
view_model.citation_text_qa_template = (
config.synthesizer.citation_text_qa_template
)
view_model.citation_multimodal_qa_template = (
config.synthesizer.citation_multimodal_qa_template
)

view_model.search_api_key = config.search.search_api_key or os.environ.get(
"BING_SEARCH_KEY"
Expand Down Expand Up @@ -340,6 +346,12 @@ def to_app_config(self):
config["synthesizer"]["use_multimodal_llm"] = self.use_mllm
config["synthesizer"]["text_qa_template"] = self.text_qa_template
config["synthesizer"]["multimodal_qa_template"] = self.multimodal_qa_template
config["synthesizer"][
"citation_text_qa_template"
] = self.citation_text_qa_template
config["synthesizer"][
"citation_multimodal_qa_template"
] = self.citation_multimodal_qa_template

config["search"]["search_api_key"] = self.search_api_key or os.environ.get(
"BING_SEARCH_KEY"
Expand Down Expand Up @@ -518,6 +530,12 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:

settings["text_qa_template"] = {"value": self.text_qa_template}
settings["multimodal_qa_template"] = {"value": self.multimodal_qa_template}
settings["citation_text_qa_template"] = {
"value": self.citation_text_qa_template
}
settings["citation_multimodal_qa_template"] = {
"value": self.citation_multimodal_qa_template
}

# search
settings["search_api_key"] = {"value": self.search_api_key}
Expand Down
6 changes: 0 additions & 6 deletions src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,3 @@ search_api_key = ""

[rag.synthesizer]
type = "SimpleSummarize"
text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n"

[rag.trace]
type = "pai_trace"
endpoint = "http://tracing-analysis-dc-hz.aliyuncs.com:8090"
token = ""
10 changes: 6 additions & 4 deletions src/pai_rag/core/models/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from enum import Enum
from typing import List, Literal
from pydantic import BaseModel
from llama_index.core.prompts.default_prompt_selectors import (
DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pai_rag.integrations.synthesizer.pai_synthesizer import (
DEFAULT_TEXT_QA_TMPL,
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
CITATION_TEXT_QA_TMPL,
CITATION_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
)


Expand Down Expand Up @@ -46,8 +46,10 @@ class SearchWebConfig(BaseModel):

class SynthesizerConfig(BaseModel):
use_multimodal_llm: bool = False
text_qa_template: str = DEFAULT_TEXT_QA_PROMPT_SEL
text_qa_template: str = DEFAULT_TEXT_QA_TMPL
citation_text_qa_template: str = CITATION_TEXT_QA_TMPL
multimodal_qa_template: str = DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL
citation_multimodal_qa_template: str = CITATION_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL


class TraceType(str, Enum):
Expand Down
4 changes: 3 additions & 1 deletion src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ async def aquery(
elif intent != Intents.RAG:
return ValueError(f"Invalid intent {intent}")

query_bundle = PaiQueryBundle(query_str=new_question, stream=query.stream)
query_bundle = PaiQueryBundle(
query_str=new_question, stream=query.stream, citation=query.citation
)
chat_store.add_message(
session_id, ChatMessage(role=MessageRole.USER, content=query.question)
)
Expand Down
6 changes: 6 additions & 0 deletions src/pai_rag/core/rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ def resolve_synthesizer(config: RagConfig) -> PaiSynthesizer:
multimodal_qa_template=PromptTemplate(
template=config.synthesizer.multimodal_qa_template
),
citation_text_qa_template=PromptTemplate(
template=config.synthesizer.citation_text_qa_template
),
citation_multimodal_qa_template=PromptTemplate(
template=config.synthesizer.citation_multimodal_qa_template
),
)
return synthesizer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class NodeParserConfig(BaseModel):
"total_pages",
"source",
"row_number",
"image_info_list",
]


Expand Down
Loading

0 comments on commit 0fc8751

Please sign in to comment.