From 1705a9f33d77eaf304fa021319988bdec204a4b9 Mon Sep 17 00:00:00 2001
From: Yue Fei <59813791+moria97@users.noreply.github.com>
Date: Tue, 4 Jun 2024 20:50:03 +0800
Subject: [PATCH] feat: enable Gradio page to share state between tabs (#48)
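
Each tab now lives in its own module under web/tabs and returns its
components keyed by elem_id. A shared ElementManager registers them, and
view_model.to_component_settings() restores every component's value on
page load, so settings changed in one tab stay in sync with the others.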
* Remove entrypoint
* Update UI
* Revert dockerfile change
* Delete debug clauses
* Fix embed dim value
* Update view model
---
src/pai_rag/app/app.py | 3 +-
src/pai_rag/app/web/element_manager.py | 48 ++
src/pai_rag/app/web/prompts.py | 11 -
src/pai_rag/app/web/tabs/chat_tab.py | 249 ++++++++
src/pai_rag/app/web/tabs/settings_tab.py | 175 ++++++
src/pai_rag/app/web/tabs/upload_tab.py | 84 +++
.../app/web/{ => tabs}/vector_db_panel.py | 110 ++--
src/pai_rag/app/web/ui.py | 549 ------------------
src/pai_rag/app/web/ui_constants.py | 55 ++
src/pai_rag/app/web/utils.py | 5 +
src/pai_rag/app/web/view_model.py | 81 ++-
src/pai_rag/app/web/webui.py | 48 ++
12 files changed, 803 insertions(+), 615 deletions(-)
create mode 100644 src/pai_rag/app/web/element_manager.py
delete mode 100644 src/pai_rag/app/web/prompts.py
create mode 100644 src/pai_rag/app/web/tabs/chat_tab.py
create mode 100644 src/pai_rag/app/web/tabs/settings_tab.py
create mode 100644 src/pai_rag/app/web/tabs/upload_tab.py
rename src/pai_rag/app/web/{ => tabs}/vector_db_panel.py (79%)
delete mode 100644 src/pai_rag/app/web/ui.py
create mode 100644 src/pai_rag/app/web/ui_constants.py
create mode 100644 src/pai_rag/app/web/utils.py
create mode 100644 src/pai_rag/app/web/webui.py
diff --git a/src/pai_rag/app/app.py b/src/pai_rag/app/app.py
index e4c418ea..4db4797e 100644
--- a/src/pai_rag/app/app.py
+++ b/src/pai_rag/app/app.py
@@ -3,10 +3,9 @@
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api import query
from pai_rag.app.api.middleware import init_middleware
-from pai_rag.app.web.ui import create_ui
+from pai_rag.app.web.webui import create_ui
from pai_rag.app.web.rag_client import rag_client
-# UI_PATH = "/ui"
UI_PATH = ""
diff --git a/src/pai_rag/app/web/element_manager.py b/src/pai_rag/app/web/element_manager.py
new file mode 100644
index 00000000..d374bbb5
--- /dev/null
+++ b/src/pai_rag/app/web/element_manager.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING, Dict, Generator, List, Tuple
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+class ElementManager:
+ def __init__(self) -> None:
+ self._id_to_elem: Dict[str, "Component"] = {}
+ self._elem_to_id: Dict["Component", str] = {}
+
+ def add_elems(self, elem_dict: Dict[str, "Component"]) -> None:
+ r"""
+ Adds elements to manager.
+ """
+ for elem_id, elem in elem_dict.items():
+ self._id_to_elem[elem_id] = elem
+ self._elem_to_id[elem] = elem_id
+
+ def get_elem_list(self) -> List["Component"]:
+ r"""
+ Returns the list of all elements.
+ """
+ return list(self._id_to_elem.values())
+
+ def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
+ r"""
+ Returns an iterator over all elements with their names.
+ """
+ for elem_id, elem in self._id_to_elem.items():
+ yield elem_id.split(".")[-1], elem
+
+ def get_elem_by_id(self, elem_id: str) -> "Component":
+ r"""
+ Gets element by id.
+
+    Example elem_ids: "top.lang", "train.dataset" (dot-namespaced ids are supported).
+ """
+ return self._id_to_elem[elem_id]
+
+ def get_id_by_elem(self, elem: "Component") -> str:
+ r"""
+ Gets id by element.
+ """
+ return self._elem_to_id[elem]
+
+
+elem_manager = ElementManager()
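The manager keys components by their Gradio elem_id, so ids must be unique
across tabs. A minimal usage sketch (the Textbox below is illustrative, not
part of this patch):

    import gradio as gr
    from pai_rag.app.web.element_manager import elem_manager

    with gr.Blocks():
        # Tabs register their components under stable elem_ids.
        question = gr.Textbox(label="Enter your question.", elem_id="question")
        elem_manager.add_elems({"question": question})

    # Any module can later resolve the component, or its id, via the manager.
    assert elem_manager.get_elem_by_id("question") is question
    assert elem_manager.get_id_by_elem(question) == "question"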
diff --git a/src/pai_rag/app/web/prompts.py b/src/pai_rag/app/web/prompts.py
deleted file mode 100644
index 2b6be084..00000000
--- a/src/pai_rag/app/web/prompts.py
+++ /dev/null
@@ -1,11 +0,0 @@
-SIMPLE_PROMPTS = "参考内容如下:\n{context_str}\n作为个人知识答疑助手,请根据上述参考内容回答下面问题,答案中不允许包含编造内容。\n用户问题:\n{query_str}"
-GENERAL_PROMPTS = '基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context_str}\n=====\n用户问题:\n{query_str}'
-EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
-ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
-
-PROMPT_MAP = {
- SIMPLE_PROMPTS: "Simple",
- GENERAL_PROMPTS: "General",
- EXTRACT_URL_PROMPTS: "Extract URL",
- ACCURATE_CONTENT_PROMPTS: "Accurate Content",
-}
diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py
new file mode 100644
index 00000000..5ada0115
--- /dev/null
+++ b/src/pai_rag/app/web/tabs/chat_tab.py
@@ -0,0 +1,249 @@
+from typing import Dict, Any
+import gradio as gr
+from pai_rag.app.web.rag_client import rag_client
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.ui_constants import (
+ SIMPLE_PROMPTS,
+ GENERAL_PROMPTS,
+ EXTRACT_URL_PROMPTS,
+ ACCURATE_CONTENT_PROMPTS,
+)
+
+
+current_session_id = None
+
+
+def clear_history(chatbot):
+ chatbot = []
+ global current_session_id
+ current_session_id = None
+ return chatbot, 0
+
+
+def respond(input_elements: Dict[Any, Any]):
+ global current_session_id
+
+ update_dict = {}
+ for element, value in input_elements.items():
+ update_dict[element.elem_id] = value
+
+ # empty input.
+ if not update_dict["question"]:
+ return "", update_dict["chatbot"], 0
+
+ view_model.update(update_dict)
+ new_config = view_model.to_app_config()
+ rag_client.reload_config(new_config)
+
+ query_type = update_dict["query_type"]
+ msg = update_dict["question"]
+ chatbot = update_dict["chatbot"]
+
+ if query_type == "LLM":
+ response = rag_client.query_llm(
+ msg,
+ session_id=current_session_id,
+ )
+
+ elif query_type == "Retrieval":
+ response = rag_client.query_vector(msg)
+ else:
+ response = rag_client.query(msg, session_id=current_session_id)
+ if update_dict["include_history"]:
+ current_session_id = response.session_id
+ else:
+ current_session_id = None
+ chatbot.append((msg, response.answer))
+ return "", chatbot, 0
+
+
+def create_chat_tab() -> Dict[str, Any]:
+ with gr.Row():
+ with gr.Column(scale=2):
+ query_type = gr.Radio(
+ ["Retrieval", "LLM", "RAG (Retrieval + LLM)"],
+ label="\N{fire} Which query do you want to use?",
+ elem_id="query_type",
+ value="RAG (Retrieval + LLM)",
+ )
+
+ with gr.Column(visible=True) as vs_col:
+ vec_model_argument = gr.Accordion("Parameters of Vector Retrieval")
+
+ with vec_model_argument:
+ similarity_top_k = gr.Slider(
+ minimum=0,
+ maximum=100,
+ step=1,
+ elem_id="similarity_top_k",
+ label="Top K (choose between 0 and 100)",
+ )
+ # similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)")
+ rerank_model = gr.Radio(
+ [
+ "no-reranker",
+ "bge-reranker-base",
+ "bge-reranker-large",
+ "llm-reranker",
+ ],
+ label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)",
+ elem_id="rerank_model",
+ )
+ retrieval_mode = gr.Radio(
+ ["Embedding Only", "Keyword Ensembled", "Keyword Only"],
+ label="Retrieval Mode",
+ elem_id="retrieval_mode",
+ )
+ vec_args = {
+ similarity_top_k,
+ # similarity_cutoff,
+ rerank_model,
+ retrieval_mode,
+ }
+ with gr.Column(visible=True) as llm_col:
+ model_argument = gr.Accordion("Inference Parameters of LLM")
+ with model_argument:
+ include_history = gr.Checkbox(
+ label="Chat history",
+ info="Query with chat history.",
+ elem_id="include_history",
+ )
+ llm_topk = gr.Slider(
+ minimum=0,
+ maximum=100,
+ step=1,
+ value=30,
+ elem_id="llm_topk",
+ label="Top K (choose between 0 and 100)",
+ )
+ llm_topp = gr.Slider(
+ minimum=0,
+ maximum=1,
+ step=0.01,
+ value=0.8,
+ elem_id="llm_topp",
+ label="Top P (choose between 0 and 1)",
+ )
+ llm_temp = gr.Slider(
+ minimum=0,
+ maximum=1,
+ step=0.01,
+ value=0.7,
+ elem_id="llm_temp",
+ label="Temperature (choose between 0 and 1)",
+ )
+ llm_args = {llm_topk, llm_topp, llm_temp, include_history}
+
+ with gr.Column(visible=True) as lc_col:
+ prm_type = gr.Radio(
+ [
+ "Simple",
+ "General",
+ "Extract URL",
+ "Accurate Content",
+ "Custom",
+ ],
+ label="\N{rocket} Please choose the prompt template type",
+ elem_id="prm_type",
+ )
+ text_qa_template = gr.Textbox(
+ label="prompt template",
+ placeholder="",
+ elem_id="text_qa_template",
+ lines=4,
+ )
+
+ def change_prompt_template(prm_type):
+ if prm_type == "Simple":
+ return {
+ text_qa_template: gr.update(
+ value=SIMPLE_PROMPTS, interactive=False
+ )
+ }
+ elif prm_type == "General":
+ return {
+ text_qa_template: gr.update(
+ value=GENERAL_PROMPTS, interactive=False
+ )
+ }
+ elif prm_type == "Extract URL":
+ return {
+ text_qa_template: gr.update(
+ value=EXTRACT_URL_PROMPTS, interactive=False
+ )
+ }
+ elif prm_type == "Accurate Content":
+ return {
+ text_qa_template: gr.update(
+ value=ACCURATE_CONTENT_PROMPTS,
+ interactive=False,
+ )
+ }
+ else:
+ return {text_qa_template: gr.update(value="", interactive=True)}
+
+ prm_type.change(
+ fn=change_prompt_template,
+ inputs=prm_type,
+ outputs=[text_qa_template],
+ )
+
+ cur_tokens = gr.Textbox(
+ label="\N{fire} Current total count of tokens", visible=False
+ )
+
+ def change_query_radio(query_type):
+ global current_session_id
+ current_session_id = None
+ if query_type == "Retrieval":
+ return {
+ vs_col: gr.update(visible=True),
+ llm_col: gr.update(visible=False),
+ lc_col: gr.update(visible=False),
+ }
+ elif query_type == "LLM":
+ return {
+ vs_col: gr.update(visible=False),
+ llm_col: gr.update(visible=True),
+ lc_col: gr.update(visible=False),
+ }
+ elif query_type == "RAG (Retrieval + LLM)":
+ return {
+ vs_col: gr.update(visible=True),
+ llm_col: gr.update(visible=True),
+ lc_col: gr.update(visible=True),
+ }
+
+ query_type.change(
+ fn=change_query_radio,
+ inputs=query_type,
+ outputs=[vs_col, llm_col, lc_col],
+ )
+
+ with gr.Column(scale=8):
+ chatbot = gr.Chatbot(height=500, elem_id="chatbot")
+ question = gr.Textbox(label="Enter your question.", elem_id="question")
+ with gr.Row():
+ submitBtn = gr.Button("Submit", variant="primary")
+ clearBtn = gr.Button("Clear History", variant="secondary")
+
+ chat_args = (
+ {text_qa_template, question, query_type, chatbot}
+ .union(vec_args)
+ .union(llm_args)
+ )
+
+ submitBtn.click(
+ respond,
+ chat_args,
+ [question, chatbot, cur_tokens],
+ api_name="respond",
+ )
+ clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens])
+ return {
+ similarity_top_k.elem_id: similarity_top_k,
+ rerank_model.elem_id: rerank_model,
+ retrieval_mode.elem_id: retrieval_mode,
+ prm_type.elem_id: prm_type,
+ text_qa_template.elem_id: text_qa_template,
+ }
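Because submitBtn.click receives chat_args as a Python set, Gradio invokes
respond with a single dict mapping each input component to its current
value; that is why respond iterates input_elements.items() and indexes by
elem_id. A self-contained sketch of the same convention (components here
are illustrative, not from this patch):

    import gradio as gr

    def echo(inputs):
        # inputs is {component: value}; elem_id gives a stable key.
        values = {comp.elem_id: val for comp, val in inputs.items()}
        return values["question"]

    with gr.Blocks() as demo:
        question = gr.Textbox(elem_id="question")
        answer = gr.Textbox(elem_id="answer")
        # Passing a set (not a list) of inputs selects the dict calling style.
        question.submit(echo, {question}, [answer])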
diff --git a/src/pai_rag/app/web/tabs/settings_tab.py b/src/pai_rag/app/web/tabs/settings_tab.py
new file mode 100644
index 00000000..f7614903
--- /dev/null
+++ b/src/pai_rag/app/web/tabs/settings_tab.py
@@ -0,0 +1,175 @@
+from typing import Dict, Any
+import gradio as gr
+import datetime
+import traceback
+from pai_rag.app.web.ui_constants import (
+ EMBEDDING_API_KEY_DICT,
+ DEFAULT_EMBED_SIZE,
+ EMBEDDING_DIM_DICT,
+ LLM_MODEL_KEY_DICT,
+)
+from pai_rag.app.web.rag_client import rag_client
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.utils import components_to_dict
+from pai_rag.app.web.tabs.vector_db_panel import create_vector_db_panel
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def connect_vector_db(input_elements: Dict[Any, Any]):
+ try:
+ update_dict = {}
+ for element, value in input_elements.items():
+ update_dict[element.elem_id] = value
+
+ view_model.update(update_dict)
+ new_config = view_model.to_app_config()
+ rag_client.reload_config(new_config)
+ return f"[{datetime.datetime.now()}] Connect vector db success!"
+ except Exception as ex:
+ logger.critical(f"[Critical] Connect failed. {traceback.format_exc()}")
+ return f"Connect failed. Please check: {ex}"
+
+
+def create_setting_tab() -> Dict[str, Any]:
+ components = []
+ with gr.Row():
+ with gr.Column():
+ with gr.Column():
+ _ = gr.Markdown(value="**Please choose your embedding model.**")
+ embed_source = gr.Dropdown(
+ EMBEDDING_API_KEY_DICT.keys(),
+ label="Embedding Type",
+ elem_id="embed_source",
+ )
+ embed_model = gr.Dropdown(
+ EMBEDDING_DIM_DICT.keys(),
+ label="Embedding Model Name",
+ elem_id="embed_model",
+ visible=False,
+ )
+ embed_dim = gr.Textbox(
+ label="Embedding Dimension",
+ elem_id="embed_dim",
+ )
+
+ def change_emb_source(source):
+ view_model.embed_source = source
+ return {
+ embed_model: gr.update(visible=(source == "HuggingFace")),
+ embed_dim: EMBEDDING_DIM_DICT.get(
+ view_model.embed_model, DEFAULT_EMBED_SIZE
+ )
+ if source == "HuggingFace"
+ else DEFAULT_EMBED_SIZE,
+ }
+
+ def change_emb_model(model):
+ view_model.embed_model = model
+ return {
+ embed_dim: EMBEDDING_DIM_DICT.get(
+ view_model.embed_model, DEFAULT_EMBED_SIZE
+ )
+ if view_model.embed_source == "HuggingFace"
+ else DEFAULT_EMBED_SIZE,
+ }
+
+ embed_source.change(
+ fn=change_emb_source,
+ inputs=embed_source,
+ outputs=[embed_model, embed_dim],
+ )
+ embed_model.change(
+ fn=change_emb_model,
+ inputs=embed_model,
+ outputs=[embed_dim],
+ )
+ components.extend([embed_source, embed_dim, embed_model])
+
+ with gr.Column():
+ _ = gr.Markdown(value="**Please set your LLM.**")
+ llm = gr.Dropdown(
+ ["PaiEas", "OpenAI", "DashScope"],
+ label="LLM Model Source",
+ elem_id="llm",
+ )
+ with gr.Column(visible=(view_model.llm == "PaiEas")) as eas_col:
+ llm_eas_url = gr.Textbox(
+ label="EAS Url",
+ elem_id="llm_eas_url",
+ )
+ llm_eas_token = gr.Textbox(
+ label="EAS Token",
+ elem_id="llm_eas_token",
+ )
+ llm_eas_model_name = gr.Textbox(
+ label="EAS Model name",
+ elem_id="llm_eas_model_name",
+ )
+ with gr.Column(visible=(view_model.llm != "PaiEas")) as api_llm_col:
+ llm_api_model_name = gr.Dropdown(
+ label="LLM Model Name",
+ elem_id="llm_api_model_name",
+ )
+
+ components.extend(
+ [
+ llm,
+ llm_eas_url,
+ llm_eas_token,
+ llm_eas_model_name,
+ llm_api_model_name,
+ ]
+ )
+
+ def change_llm(value):
+ view_model.llm = value
+ eas_visible = value == "PaiEas"
+ api_visible = value != "PaiEas"
+ model_options = LLM_MODEL_KEY_DICT.get(value, [])
+ cur_model = model_options[0] if model_options else ""
+ return {
+ eas_col: gr.update(visible=eas_visible),
+ api_llm_col: gr.update(visible=api_visible),
+ llm_api_model_name: gr.update(
+ choices=model_options, value=cur_model
+ ),
+ }
+
+ llm.change(
+ fn=change_llm,
+ inputs=llm,
+ outputs=[eas_col, api_llm_col, llm_api_model_name],
+ )
+ """
+ with gr.Column():
+ _ = gr.Markdown(
+ value="**(Optional) Please upload your config file.**"
+ )
+ config_file = gr.File(
+ value=view_model.config_file,
+ label="Upload a local config json file",
+ file_types=[".json"],
+ file_count="single",
+ interactive=True,
+ )
+ cfg_btn = gr.Button("Parse Config", variant="primary")
+ """
+ vector_db_elems = create_vector_db_panel(
+ input_elements={
+ llm,
+ llm_eas_url,
+ llm_eas_token,
+ llm_eas_model_name,
+ embed_source,
+ embed_model,
+ embed_dim,
+ llm_api_model_name,
+ },
+ connect_vector_func=connect_vector_db,
+ )
+
+ elems = components_to_dict(components)
+ elems.update(vector_db_elems)
+ return elems
diff --git a/src/pai_rag/app/web/tabs/upload_tab.py b/src/pai_rag/app/web/tabs/upload_tab.py
new file mode 100644
index 00000000..8dba2c6b
--- /dev/null
+++ b/src/pai_rag/app/web/tabs/upload_tab.py
@@ -0,0 +1,84 @@
+import os
+from typing import Dict, Any
+import gradio as gr
+from pai_rag.app.web.rag_client import rag_client
+from pai_rag.app.web.view_model import view_model
+
+
+def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
+ view_model.chunk_size = chunk_size
+ view_model.chunk_overlap = chunk_overlap
+ new_config = view_model.to_app_config()
+ rag_client.reload_config(new_config)
+
+ if not upload_files:
+ return "No file selected. Please choose at least one file."
+
+ for file in upload_files:
+ file_dir = os.path.dirname(file.name)
+ rag_client.add_knowledge(file_dir, enable_qa_extraction)
+    return (
+        f"Uploaded {len(upload_files)} files successfully!\n\n"
+        "Relevant content has been added to the vector store; "
+        "you can now start chatting and asking questions."
+    )
+
+
+def create_upload_tab() -> Dict[str, Any]:
+ with gr.Row():
+ with gr.Column(scale=2):
+ chunk_size = gr.Textbox(
+ label="\N{rocket} Chunk Size (The size of the chunks into which a document is divided)",
+ elem_id="chunk_size",
+ )
+
+ chunk_overlap = gr.Textbox(
+ label="\N{fire} Chunk Overlap (The portion of adjacent document chunks that overlap with each other)",
+ elem_id="chunk_overlap",
+ )
+ enable_qa_extraction = gr.Checkbox(
+ label="Yes",
+ info="Process with QA Extraction Model",
+ elem_id="enable_qa_extraction",
+ )
+ with gr.Column(scale=8):
+ with gr.Tab("Files"):
+ upload_file = gr.File(
+ label="Upload a knowledge file.", file_count="multiple"
+ )
+ upload_file_btn = gr.Button("Upload", variant="primary")
+ upload_file_state = gr.Textbox(label="Upload State")
+ with gr.Tab("Directory"):
+ upload_file_dir = gr.File(
+ label="Upload a knowledge directory.",
+ file_count="directory",
+ )
+ upload_dir_btn = gr.Button("Upload", variant="primary")
+ upload_dir_state = gr.Textbox(label="Upload State")
+ upload_file_btn.click(
+ fn=upload_knowledge,
+ inputs=[
+ upload_file,
+ chunk_size,
+ chunk_overlap,
+ enable_qa_extraction,
+ ],
+ outputs=upload_file_state,
+ api_name="upload_knowledge",
+ )
+ upload_dir_btn.click(
+ fn=upload_knowledge,
+ inputs=[
+ upload_file_dir,
+ chunk_size,
+ chunk_overlap,
+ enable_qa_extraction,
+ ],
+ outputs=upload_dir_state,
+ api_name="upload_knowledge_dir",
+ )
+ return {
+ chunk_size.elem_id: chunk_size,
+ chunk_overlap.elem_id: chunk_overlap,
+ enable_qa_extraction.elem_id: enable_qa_extraction,
+ }
diff --git a/src/pai_rag/app/web/vector_db_panel.py b/src/pai_rag/app/web/tabs/vector_db_panel.py
similarity index 79%
rename from src/pai_rag/app/web/vector_db_panel.py
rename to src/pai_rag/app/web/tabs/vector_db_panel.py
index 8b7db86e..fc66df8a 100644
--- a/src/pai_rag/app/web/vector_db_panel.py
+++ b/src/pai_rag/app/web/tabs/vector_db_panel.py
@@ -1,14 +1,15 @@
import gradio as gr
-from typing import Any, Set, Callable
-from pai_rag.app.web.view_model import ViewModel
+from typing import Any, Set, Callable, Dict
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.utils import components_to_dict
def create_vector_db_panel(
- view_model: ViewModel,
input_elements: Set[Any],
connect_vector_func: Callable[[Any], str],
-):
- with gr.Column() as panel:
+) -> Dict[str, Any]:
+ components = []
+ with gr.Column():
with gr.Column():
_ = gr.Markdown(
value=f"**Please check your Vector Store for {view_model.vectordb_type}.**"
@@ -16,7 +17,6 @@ def create_vector_db_panel(
vectordb_type = gr.Dropdown(
["Hologres", "Milvus", "ElasticSearch", "AnalyticDB", "FAISS"],
label="Which VectorStore do you want to use?",
- value=view_model.vectordb_type,
elem_id="vectordb_type",
)
# Adb
@@ -26,13 +26,11 @@ def create_vector_db_panel(
adb_ak = gr.Textbox(
label="access-key-id",
type="password",
- value=view_model.adb_ak,
elem_id="adb_ak",
)
adb_sk = gr.Textbox(
label="access-key-secret",
type="password",
- value=view_model.adb_sk,
elem_id="adb_sk",
)
adb_region_id = gr.Dropdown(
@@ -42,35 +40,28 @@ def create_vector_db_panel(
"cn-zhangjiakou",
"cn-huhehaote",
"cn-shanghai",
- " cn-shenzhen",
+ "cn-shenzhen",
"cn-chengdu",
],
label="RegionId",
- value=view_model.adb_region_id,
elem_id="adb_region_id",
)
adb_instance_id = gr.Textbox(
label="InstanceId",
- value=view_model.adb_instance_id,
elem_id="adb_instance_id",
)
- adb_account = gr.Textbox(
- label="Account", value=view_model.adb_account, elem_id="adb_account"
- )
+ adb_account = gr.Textbox(label="Account", elem_id="adb_account")
adb_account_password = gr.Textbox(
label="Password",
type="password",
- value=view_model.adb_account_password,
elem_id="adb_account_password",
)
adb_namespace = gr.Textbox(
label="Namespace",
- value=view_model.adb_namespace,
elem_id="adb_namespace",
)
adb_collection = gr.Textbox(
label="CollectionName",
- value=view_model.adb_collection,
elem_id="adb_collection",
)
@@ -101,36 +92,35 @@ def create_vector_db_panel(
) as holo_col:
hologres_host = gr.Textbox(
label="Host",
- value=view_model.hologres_host,
elem_id="hologres_host",
)
- hologres_user = gr.Textbox(
- label="User",
- value=view_model.hologres_user,
- elem_id="hologres_user",
+ hologres_port = gr.Textbox(
+ label="Port",
+ elem_id="hologres_port",
)
hologres_database = gr.Textbox(
label="Database",
- value=view_model.hologres_database,
elem_id="hologres_database",
)
+ hologres_user = gr.Textbox(
+ label="User",
+ elem_id="hologres_user",
+ )
hologres_password = gr.Textbox(
label="Password",
type="password",
- value=view_model.hologres_password,
elem_id="hologres_password",
)
hologres_table = gr.Textbox(
label="Table",
- value=view_model.hologres_table,
elem_id="hologres_table",
)
hologres_pre_delete = gr.Dropdown(
["True", "False"],
label="Pre Delete",
- value=view_model.hologres_pre_delete,
elem_id="hologres_pre_delete",
)
+
connect_btn_hologres = gr.Button("Connect Hologres", variant="primary")
con_state_hologres = gr.Textbox(label="Connection Info: ")
inputs_hologres = input_elements.union(
@@ -154,21 +144,15 @@ def create_vector_db_panel(
with gr.Column(
visible=(view_model.vectordb_type == "ElasticSearch")
) as es_col:
- es_url = gr.Textbox(
- label="ElasticSearch Url", value=view_model.es_url, elem_id="es_url"
- )
- es_index = gr.Textbox(
- label="Index Name", value=view_model.es_index, elem_id="es_index"
- )
- es_user = gr.Textbox(
- label="ES User", value=view_model.es_user, elem_id="es_user"
- )
+ es_url = gr.Textbox(label="ElasticSearch Url", elem_id="es_url")
+ es_index = gr.Textbox(label="Index Name", elem_id="es_index")
+ es_user = gr.Textbox(label="ES User", elem_id="es_user")
es_password = gr.Textbox(
label="ES password",
type="password",
- value=view_model.es_password,
elem_id="es_password",
)
+
inputs_es = input_elements.union(
{vectordb_type, es_url, es_index, es_user, es_password}
)
@@ -184,29 +168,21 @@ def create_vector_db_panel(
with gr.Column(
visible=(view_model.vectordb_type == "Milvus")
) as milvus_col:
- milvus_host = gr.Textbox(
- label="Host", value=view_model.milvus_host, elem_id="milvus_host"
- )
- milvus_port = gr.Textbox(
- label="Port", value=view_model.milvus_port, elem_id="milvus_port"
- )
- milvus_user = gr.Textbox(
- label="User", value=view_model.milvus_user, elem_id="milvus_user"
- )
+ milvus_host = gr.Textbox(label="Host", elem_id="milvus_host")
+ milvus_port = gr.Textbox(label="Port", elem_id="milvus_port")
+
+ milvus_user = gr.Textbox(label="User", elem_id="milvus_user")
milvus_password = gr.Textbox(
label="Password",
type="password",
- value=view_model.milvus_password,
elem_id="milvus_password",
)
milvus_database = gr.Textbox(
label="Database",
- value=view_model.milvus_database,
elem_id="milvus_database",
)
milvus_collection_name = gr.Textbox(
label="Collection name",
- value=view_model.milvus_collection_name,
elem_id="milvus_collection_name",
)
@@ -231,9 +207,7 @@ def create_vector_db_panel(
)
with gr.Column(visible=(view_model.vectordb_type == "FAISS")) as faiss_col:
- faiss_path = gr.Textbox(
- label="Path", value=view_model.faiss_path, elem_id="faiss_path"
- )
+ faiss_path = gr.Textbox(label="Path", elem_id="faiss_path")
connect_btn_faiss = gr.Button("Connect Faiss", variant="primary")
con_state_faiss = gr.Textbox(label="Connection Info: ")
inputs_faiss = input_elements.union({vectordb_type, faiss_path})
@@ -275,4 +249,36 @@ def change_vectordb_conn(vectordb_type):
outputs=[adb_col, holo_col, faiss_col, es_col, milvus_col],
)
- return panel
+ components.extend(
+ [
+ vectordb_type,
+ adb_ak,
+ adb_sk,
+ adb_region_id,
+ adb_instance_id,
+ adb_collection,
+ adb_account,
+ adb_account_password,
+ adb_namespace,
+ hologres_host,
+ hologres_port,
+ hologres_database,
+ hologres_user,
+ hologres_password,
+ hologres_table,
+ hologres_pre_delete,
+ milvus_host,
+ milvus_port,
+ milvus_database,
+ milvus_collection_name,
+ milvus_user,
+ milvus_password,
+ faiss_path,
+ es_url,
+ es_index,
+ es_user,
+ es_password,
+ ]
+ )
+
+ return components_to_dict(components)
diff --git a/src/pai_rag/app/web/ui.py b/src/pai_rag/app/web/ui.py
deleted file mode 100644
index b0e96d89..00000000
--- a/src/pai_rag/app/web/ui.py
+++ /dev/null
@@ -1,549 +0,0 @@
-import datetime
-import gradio as gr
-import os
-from typing import List, Any
-from pai_rag.app.web.view_model import view_model
-from pai_rag.app.web.rag_client import rag_client
-from pai_rag.app.web.vector_db_panel import create_vector_db_panel
-from pai_rag.app.web.prompts import (
- SIMPLE_PROMPTS,
- GENERAL_PROMPTS,
- EXTRACT_URL_PROMPTS,
- ACCURATE_CONTENT_PROMPTS,
- PROMPT_MAP,
-)
-
-import logging
-import traceback
-
-logger = logging.getLogger("WebUILogger")
-
-welcome_message_markdown = """
- # \N{fire} Chatbot with RAG on PAI !
- ### \N{rocket} Build your own personalized knowledge base question-answering chatbot.
-
- #### \N{fire} Platform: [PAI](https://help.aliyun.com/zh/pai) / [PAI-EAS](https://www.aliyun.com/product/bigdata/learn/eas) / [PAI-DSW](https://pai.console.aliyun.com/notebook) \N{rocket} Supported VectorStores: [Hologres](https://www.aliyun.com/product/bigdata/hologram) / [ElasticSearch](https://www.aliyun.com/product/bigdata/elasticsearch) / [AnalyticDB](https://www.aliyun.com/product/apsaradb/gpdb) / [FAISS](https://python.langchain.com/docs/integrations/vectorstores/faiss)
-
- #### \N{fire} API Docs \N{rocket} \N{fire} 欢迎加入【PAI】RAG答疑群 27370042974
- """
-
-css_style = """
- h1, h3, h4 {
- text-align: center;
- display:block;
- }
- """
-
-DEFAULT_EMBED_SIZE = 1536
-DEFAULT_HF_EMBED_MODEL = "bge-small-zh-v1.5"
-
-embedding_dim_dict = {
- "bge-small-zh-v1.5": 1024,
- "SGPT-125M-weightedmean-nli-bitfit": 768,
- "text2vec-large-chinese": 1024,
- "text2vec-base-chinese": 768,
- "paraphrase-multilingual-MiniLM-L12-v2": 384,
-}
-
-embedding_api_key_dict = {"HuggingFace": False, "OpenAI": True, "DashScope": True}
-
-llm_model_key_dict = {
- "DashScope": [
- "qwen-turbo",
- "qwen-plus",
- "qwen-max",
- "qwen-max-1201",
- "qwen-max-longcontext",
- ],
- "OpenAI": [
- "gpt-3.5-turbo",
- "gpt-4-turbo",
- ],
-}
-
-current_session_id = None
-
-
-def connect_vector_db(input_elements: List[Any]):
- try:
- update_dict = {}
- for element, value in input_elements.items():
- update_dict[element.elem_id] = value
-
- view_model.update(update_dict)
- new_config = view_model.to_app_config()
- rag_client.reload_config(new_config)
- return f"[{datetime.datetime.now()}] Connect vector db success!"
- except Exception as ex:
- logger.critical(f"[Critical] Connect failed. {traceback.format_exc()}")
- return f"Connect failed. Please check: {ex}"
-
-
-def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
- view_model.chunk_size = chunk_size
- view_model.chunk_overlap = chunk_overlap
- new_config = view_model.to_app_config()
- rag_client.reload_config(new_config)
-
- if not upload_files:
- return "No file selected. Please choose at least one file."
-
- for file in upload_files:
- file_dir = os.path.dirname(file.name)
- rag_client.add_knowledge(file_dir, enable_qa_extraction)
- return (
- "Upload "
- + str(len(upload_files))
- + " files Success! \n \n Relevant content has been added to the vector store, you can now start chatting and asking questions."
- )
-
-
-def clear_history(chatbot):
- chatbot = []
- global current_session_id
- current_session_id = None
- return chatbot, 0
-
-
-def respond(input_elements: List[Any]):
- global current_session_id
-
- update_dict = {}
- for element, value in input_elements.items():
- update_dict[element.elem_id] = value
-
- # empty input.
- if not update_dict["question"]:
- return "", update_dict["chatbot"], 0
-
- view_model.update(update_dict)
- new_config = view_model.to_app_config()
- rag_client.reload_config(new_config)
-
- query_type = update_dict["query_type"]
- msg = update_dict["question"]
- chatbot = update_dict["chatbot"]
-
- if query_type == "LLM":
- response = rag_client.query_llm(
- msg,
- session_id=current_session_id,
- )
-
- elif query_type == "Retrieval":
- response = rag_client.query_vector(msg)
- else:
- response = rag_client.query(msg, session_id=current_session_id)
- if update_dict["include_history"]:
- current_session_id = response.session_id
- else:
- current_session_id = None
- chatbot.append((msg, response.answer))
- return "", chatbot, 0
-
-
-def create_ui():
- with gr.Blocks(css=css_style) as homepage:
- gr.Markdown(value=welcome_message_markdown)
-
- with gr.Tab("\N{rocket} Settings"):
- with gr.Row():
- with gr.Column():
- with gr.Column():
- _ = gr.Markdown(value="**Please choose your embedding model.**")
- embed_source = gr.Dropdown(
- embedding_api_key_dict.keys(),
- label="Embedding Type",
- value=view_model.embed_source,
- elem_id="embed_source",
- )
- embed_model = gr.Dropdown(
- embedding_dim_dict.keys(),
- label="Embedding Model Name",
- value=view_model.embed_model,
- elem_id="embed_model",
- visible=(view_model.embed_source == "HuggingFace"),
- )
- embed_dim = gr.Textbox(
- label="Embedding Dimension",
- value=embedding_dim_dict.get(
- view_model.embed_model, DEFAULT_EMBED_SIZE
- ),
- elem_id="embed_dim",
- )
-
- def change_emb_source(source):
- view_model.embed_source = source
- view_model.embed_model = (
- DEFAULT_HF_EMBED_MODEL
- if source == "HuggingFace"
- else source
- )
- _embed_dim = (
- embedding_dim_dict.get(
- view_model.embed_model, DEFAULT_EMBED_SIZE
- )
- if source == "HuggingFace"
- else DEFAULT_EMBED_SIZE
- )
- return {
- embed_model: gr.update(
- visible=(source == "HuggingFace"),
- value=view_model.embed_model,
- ),
- embed_dim: _embed_dim,
- }
-
- def change_emb_model(model):
- view_model.embed_model = model
- return {
- embed_dim: embedding_dim_dict.get(
- view_model.embed_model, DEFAULT_EMBED_SIZE
- ),
- }
-
- embed_source.change(
- fn=change_emb_source,
- inputs=embed_source,
- outputs=[embed_model, embed_dim],
- )
- embed_model.change(
- fn=change_emb_model,
- inputs=embed_model,
- outputs=[embed_dim],
- )
-
- with gr.Column():
- _ = gr.Markdown(value="**Please set your LLM.**")
- llm_src = gr.Dropdown(
- ["PaiEas", "OpenAI", "DashScope"],
- label="LLM Model Source",
- value=view_model.llm,
- elem_id="llm",
- )
- with gr.Column(visible=(view_model.llm == "PaiEas")) as eas_col:
- llm_eas_url = gr.Textbox(
- label="EAS Url",
- value=view_model.llm_eas_url,
- elem_id="llm_eas_url",
- )
- llm_eas_token = gr.Textbox(
- label="EAS Token",
- value=view_model.llm_eas_token,
- elem_id="llm_eas_token",
- )
- llm_eas_model_name = gr.Textbox(
- label="EAS Model name",
- value=view_model.llm_eas_model_name,
- elem_id="llm_eas_model_name",
- )
- with gr.Column(
- visible=(view_model.llm != "PaiEas")
- ) as api_llm_col:
- llm_api_model_name = gr.Dropdown(
- llm_model_key_dict.get(view_model.llm, []),
- label="LLM Model Name",
- value=view_model.llm_api_model_name,
- elem_id="llm_api_model_name",
- )
-
- def change_llm_src(value):
- view_model.llm = value
- eas_visible = value == "PaiEas"
- api_visible = value != "PaiEas"
- model_options = llm_model_key_dict.get(value, [])
- cur_model = model_options[0] if model_options else ""
- return {
- eas_col: gr.update(visible=eas_visible),
- api_llm_col: gr.update(visible=api_visible),
- llm_api_model_name: gr.update(
- choices=model_options, value=cur_model
- ),
- }
-
- llm_src.change(
- fn=change_llm_src,
- inputs=llm_src,
- outputs=[eas_col, api_llm_col, llm_api_model_name],
- )
- """
- with gr.Column():
- _ = gr.Markdown(
- value="**(Optional) Please upload your config file.**"
- )
- config_file = gr.File(
- value=view_model.config_file,
- label="Upload a local config json file",
- file_types=[".json"],
- file_count="single",
- interactive=True,
- )
- cfg_btn = gr.Button("Parse Config", variant="primary")
- """
- create_vector_db_panel(
- view_model=view_model,
- input_elements={
- llm_src,
- llm_eas_url,
- llm_eas_token,
- llm_eas_model_name,
- embed_source,
- embed_model,
- embed_dim,
- llm_api_model_name,
- },
- connect_vector_func=connect_vector_db,
- )
-
- with gr.Tab("\N{whale} Upload"):
- with gr.Row():
- with gr.Column(scale=2):
- chunk_size = gr.Textbox(
- label="\N{rocket} Chunk Size (The size of the chunks into which a document is divided)",
- value=view_model.chunk_size,
- )
- chunk_overlap = gr.Textbox(
- label="\N{fire} Chunk Overlap (The portion of adjacent document chunks that overlap with each other)",
- value=view_model.chunk_overlap,
- )
- enable_qa_extraction = gr.Checkbox(
- label="Yes",
- info="Process with QA Extraction Model",
- value=view_model.enable_qa_extraction,
- elem_id="enable_qa_extraction",
- )
- with gr.Column(scale=8):
- with gr.Tab("Files"):
- upload_file = gr.File(
- label="Upload a knowledge file.", file_count="multiple"
- )
- upload_file_btn = gr.Button("Upload", variant="primary")
- upload_file_state = gr.Textbox(label="Upload State")
- with gr.Tab("Directory"):
- upload_file_dir = gr.File(
- label="Upload a knowledge directory.",
- file_count="directory",
- )
- upload_dir_btn = gr.Button("Upload", variant="primary")
- upload_dir_state = gr.Textbox(label="Upload State")
- upload_file_btn.click(
- fn=upload_knowledge,
- inputs=[
- upload_file,
- chunk_size,
- chunk_overlap,
- enable_qa_extraction,
- ],
- outputs=upload_file_state,
- api_name="upload_knowledge",
- )
- upload_dir_btn.click(
- fn=upload_knowledge,
- inputs=[
- upload_file_dir,
- chunk_size,
- chunk_overlap,
- enable_qa_extraction,
- ],
- outputs=upload_dir_state,
- api_name="upload_knowledge_dir",
- )
-
- with gr.Tab("\N{fire} Chat"):
- with gr.Row():
- with gr.Column(scale=2):
- query_type = gr.Radio(
- ["Retrieval", "LLM", "RAG (Retrieval + LLM)"],
- label="\N{fire} Which query do you want to use?",
- elem_id="query_type",
- value="RAG (Retrieval + LLM)",
- )
-
- with gr.Column(visible=True) as vs_col:
- vec_model_argument = gr.Accordion(
- "Parameters of Vector Retrieval"
- )
-
- with vec_model_argument:
- similarity_top_k = gr.Slider(
- minimum=0,
- maximum=100,
- step=1,
- elem_id="similarity_top_k",
- value=view_model.similarity_top_k,
- label="Top K (choose between 0 and 100)",
- )
- # similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)")
- rerank_model = gr.Radio(
- [
- "no-reranker",
- "bge-reranker-base",
- "bge-reranker-large",
- "llm-reranker",
- ],
- label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)",
- elem_id="rerank_model",
- value=view_model.rerank_model,
- )
- retrieval_mode = gr.Radio(
- ["Embedding Only", "Keyword Ensembled", "Keyword Only"],
- label="Retrieval Mode",
- elem_id="retrieval_mode",
- value=view_model.retrieval_mode,
- )
- vec_args = {
- similarity_top_k,
- # similarity_cutoff,
- rerank_model,
- retrieval_mode,
- }
- with gr.Column(visible=True) as llm_col:
- model_argument = gr.Accordion("Inference Parameters of LLM")
- with model_argument:
- include_history = gr.Checkbox(
- label="Chat history",
- info="Query with chat history.",
- elem_id="include_history",
- )
- llm_topk = gr.Slider(
- minimum=0,
- maximum=100,
- step=1,
- value=30,
- elem_id="llm_topk",
- label="Top K (choose between 0 and 100)",
- )
- llm_topp = gr.Slider(
- minimum=0,
- maximum=1,
- step=0.01,
- value=0.8,
- elem_id="llm_topp",
- label="Top P (choose between 0 and 1)",
- )
- llm_temp = gr.Slider(
- minimum=0,
- maximum=1,
- step=0.01,
- value=0.7,
- elem_id="llm_temp",
- label="Temperature (choose between 0 and 1)",
- )
- llm_args = {llm_topk, llm_topp, llm_temp, include_history}
-
- with gr.Column(visible=True) as lc_col:
- prm_type = PROMPT_MAP.get(view_model.text_qa_template, "Custom")
- prm_radio = gr.Radio(
- [
- "Simple",
- "General",
- "Extract URL",
- "Accurate Content",
- "Custom",
- ],
- label="\N{rocket} Please choose the prompt template type",
- value=prm_type,
- )
- text_qa_template = gr.Textbox(
- label="prompt template",
- placeholder=view_model.text_qa_template,
- value=view_model.text_qa_template,
- elem_id="text_qa_template",
- lines=4,
- )
-
- def change_prompt_template(prm_radio):
- if prm_radio == "Simple":
- return {
- text_qa_template: gr.update(
- value=SIMPLE_PROMPTS, interactive=False
- )
- }
- elif prm_radio == "General":
- return {
- text_qa_template: gr.update(
- value=GENERAL_PROMPTS, interactive=False
- )
- }
- elif prm_radio == "Extract URL":
- return {
- text_qa_template: gr.update(
- value=EXTRACT_URL_PROMPTS, interactive=False
- )
- }
- elif prm_radio == "Accurate Content":
- return {
- text_qa_template: gr.update(
- value=ACCURATE_CONTENT_PROMPTS,
- interactive=False,
- )
- }
- else:
- return {
- text_qa_template: gr.update(
- value="", interactive=True
- )
- }
-
- prm_radio.change(
- fn=change_prompt_template,
- inputs=prm_radio,
- outputs=[text_qa_template],
- )
-
- cur_tokens = gr.Textbox(
- label="\N{fire} Current total count of tokens", visible=False
- )
-
- def change_query_radio(query_type):
- global current_session_id
- current_session_id = None
- if query_type == "Retrieval":
- return {
- vs_col: gr.update(visible=True),
- llm_col: gr.update(visible=False),
- lc_col: gr.update(visible=False),
- }
- elif query_type == "LLM":
- return {
- vs_col: gr.update(visible=False),
- llm_col: gr.update(visible=True),
- lc_col: gr.update(visible=False),
- }
- elif query_type == "RAG (Retrieval + LLM)":
- return {
- vs_col: gr.update(visible=True),
- llm_col: gr.update(visible=True),
- lc_col: gr.update(visible=True),
- }
-
- query_type.change(
- fn=change_query_radio,
- inputs=query_type,
- outputs=[vs_col, llm_col, lc_col],
- )
-
- with gr.Column(scale=8):
- chatbot = gr.Chatbot(height=500, elem_id="chatbot")
- question = gr.Textbox(
- label="Enter your question.", elem_id="question"
- )
- with gr.Row():
- submitBtn = gr.Button("Submit", variant="primary")
- clearBtn = gr.Button("Clear History", variant="secondary")
-
- chat_args = (
- {text_qa_template, question, query_type, chatbot}
- .union(vec_args)
- .union(llm_args)
- )
-
- submitBtn.click(
- respond,
- chat_args,
- [question, chatbot, cur_tokens],
- api_name="respond",
- )
- clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens])
-
- return homepage
diff --git a/src/pai_rag/app/web/ui_constants.py b/src/pai_rag/app/web/ui_constants.py
new file mode 100644
index 00000000..69946f6e
--- /dev/null
+++ b/src/pai_rag/app/web/ui_constants.py
@@ -0,0 +1,55 @@
+SIMPLE_PROMPTS = "参考内容如下:\n{context_str}\n作为个人知识答疑助手,请根据上述参考内容回答下面问题,答案中不允许包含编造内容。\n用户问题:\n{query_str}"
+GENERAL_PROMPTS = '基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context_str}\n=====\n用户问题:\n{query_str}'
+EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
+ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
+
+PROMPT_MAP = {
+ SIMPLE_PROMPTS: "Simple",
+ GENERAL_PROMPTS: "General",
+ EXTRACT_URL_PROMPTS: "Extract URL",
+ ACCURATE_CONTENT_PROMPTS: "Accurate Content",
+}
+
+WELCOME_MESSAGE = """
+ # \N{fire} Chatbot with RAG on PAI !
+ ### \N{rocket} Build your own personalized knowledge base question-answering chatbot.
+
+ #### \N{fire} Platform: [PAI](https://help.aliyun.com/zh/pai) / [PAI-EAS](https://www.aliyun.com/product/bigdata/learn/eas) / [PAI-DSW](https://pai.console.aliyun.com/notebook) \N{rocket} Supported VectorStores: [Hologres](https://www.aliyun.com/product/bigdata/hologram) / [ElasticSearch](https://www.aliyun.com/product/bigdata/elasticsearch) / [AnalyticDB](https://www.aliyun.com/product/apsaradb/gpdb) / [FAISS](https://python.langchain.com/docs/integrations/vectorstores/faiss)
+
+ #### \N{fire} API Docs \N{rocket} \N{fire} 欢迎加入【PAI】RAG答疑群 27370042974
+ """
+
+DEFAULT_CSS_STYLE = """
+ h1, h3, h4 {
+ text-align: center;
+ display:block;
+ }
+ """
+
+DEFAULT_EMBED_SIZE = 1536
+
+EMBEDDING_DIM_DICT = {
+ "bge-small-zh-v1.5": 1024,
+ "SGPT-125M-weightedmean-nli-bitfit": 768,
+ "text2vec-large-chinese": 1024,
+ "text2vec-base-chinese": 768,
+ "paraphrase-multilingual-MiniLM-L12-v2": 384,
+}
+
+DEFAULT_HF_EMBED_MODEL = "bge-small-zh-v1.5"
+
+EMBEDDING_API_KEY_DICT = {"HuggingFace": False, "OpenAI": True, "DashScope": True}
+
+LLM_MODEL_KEY_DICT = {
+ "DashScope": [
+ "qwen-turbo",
+ "qwen-plus",
+ "qwen-max",
+ "qwen-max-1201",
+ "qwen-max-longcontext",
+ ],
+ "OpenAI": [
+ "gpt-3.5-turbo",
+ "gpt-4-turbo",
+ ],
+}
diff --git a/src/pai_rag/app/web/utils.py b/src/pai_rag/app/web/utils.py
new file mode 100644
index 00000000..86bbe30c
--- /dev/null
+++ b/src/pai_rag/app/web/utils.py
@@ -0,0 +1,5 @@
+from typing import List, Any, Dict
+
+
+def components_to_dict(components: List[Any]) -> Dict[str, Any]:
+ return {c.elem_id: c for c in components}
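Every component passed in must define an elem_id, otherwise its key is
None. A quick sketch:

    import gradio as gr
    from pai_rag.app.web.utils import components_to_dict

    tb = gr.Textbox(elem_id="embed_dim")
    assert components_to_dict([tb]) == {"embed_dim": tb}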
diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index d095e304..598a1b1b 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -1,6 +1,13 @@
from pydantic import BaseModel
from typing import Any, Dict
from collections import defaultdict
+from pai_rag.app.web.ui_constants import (
+ EMBEDDING_DIM_DICT,
+ DEFAULT_EMBED_SIZE,
+ DEFAULT_HF_EMBED_MODEL,
+ LLM_MODEL_KEY_DICT,
+ PROMPT_MAP,
+)
def recursive_dict():
@@ -17,7 +24,7 @@ def _transform_to_dict(config):
class ViewModel(BaseModel):
# embedding
embed_source: str = "HuggingFace"
- embed_model: str = "bge-small-zh-v1.5"
+ embed_model: str = DEFAULT_HF_EMBED_MODEL
embed_dim: int = 1024
embed_api_key: str = None
@@ -299,5 +306,77 @@ def to_app_config(self):
return _transform_to_dict(config)
+ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
+ settings = {}
+ settings["embed_source"] = {"value": self.embed_source}
+ settings["embed_model"] = {
+ "value": self.embed_model,
+ "visible": self.embed_source == "HuggingFace",
+ }
+ settings["embed_dim"] = {
+ "value": EMBEDDING_DIM_DICT.get(self.embed_model, DEFAULT_EMBED_SIZE)
+ if self.embed_source == "HuggingFace"
+ else DEFAULT_EMBED_SIZE
+ }
+
+ settings["llm"] = {"value": self.llm}
+ settings["llm_eas_url"] = {"value": self.llm_eas_url}
+ settings["llm_eas_token"] = {"value": self.llm_eas_token}
+ settings["llm_eas_model_name"] = {"value": self.llm_eas_model_name}
+ settings["llm_api_model_name"] = {
+ "value": self.llm_api_model_name,
+ "choices": LLM_MODEL_KEY_DICT.get(self.llm, []),
+ }
+ settings["chunk_size"] = {"value": self.chunk_size}
+ settings["chunk_overlap"] = {"value": self.chunk_overlap}
+ settings["enable_qa_extraction"] = {"value": self.enable_qa_extraction}
+ settings["similarity_top_k"] = {"value": self.similarity_top_k}
+ settings["rerank_model"] = {"value": self.rerank_model}
+ settings["retrieval_mode"] = {"value": self.retrieval_mode}
+
+        prm_type = PROMPT_MAP.get(self.text_qa_template, "Custom")
+ settings["prm_type"] = {"value": prm_type}
+ settings["text_qa_template"] = {"value": self.text_qa_template}
+
+ settings["vectordb_type"] = {"value": self.vectordb_type}
+
+ # adb
+ settings["adb_ak"] = {"value": self.adb_ak}
+ settings["adb_sk"] = {"value": self.adb_sk}
+ settings["adb_region_id"] = {"value": self.adb_region_id}
+ settings["adb_account"] = {"value": self.adb_account}
+ settings["adb_account_password"] = {"value": self.adb_account_password}
+ settings["adb_namespace"] = {"value": self.adb_namespace}
+ settings["adb_instance_id"] = {"value": self.adb_instance_id}
+ settings["adb_collection"] = {"value": self.adb_collection}
+
+ # hologres
+ settings["hologres_host"] = {"value": self.hologres_host}
+ settings["hologres_database"] = {"value": self.hologres_database}
+ settings["hologres_port"] = {"value": self.hologres_port}
+ settings["hologres_user"] = {"value": self.hologres_user}
+ settings["hologres_password"] = {"value": self.hologres_password}
+ settings["hologres_table"] = {"value": self.hologres_table}
+ settings["hologres_pre_delete"] = {"value": self.hologres_pre_delete}
+
+ # elasticsearch
+ settings["es_url"] = {"value": self.es_url}
+ settings["es_index"] = {"value": self.es_index}
+ settings["es_user"] = {"value": self.es_user}
+ settings["es_password"] = {"value": self.es_password}
+
+ # milvus
+ settings["milvus_host"] = {"value": self.milvus_host}
+ settings["milvus_port"] = {"value": self.milvus_port}
+ settings["milvus_database"] = {"value": self.milvus_database}
+ settings["milvus_user"] = {"value": self.milvus_user}
+ settings["milvus_password"] = {"value": self.milvus_password}
+ settings["milvus_collection_name"] = {"value": self.milvus_collection_name}
+
+ # faiss
+ settings["faiss_path"] = {"value": self.faiss_path}
+
+ return settings
+
view_model = ViewModel()
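The keys returned by to_component_settings() must cover every elem_id
registered with the element manager, since resume_ui() (webui.py below)
looks each one up and splats the entry into the component's constructor.
A sketch of that contract, assuming gradio 3.41-style components:

    import gradio as gr

    settings = {"embed_dim": {"value": 1024}}
    elem = gr.Textbox(elem_id="embed_dim")
    restored = elem.__class__(**settings[elem.elem_id]).value  # -> 1024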
diff --git a/src/pai_rag/app/web/webui.py b/src/pai_rag/app/web/webui.py
new file mode 100644
index 00000000..c920968a
--- /dev/null
+++ b/src/pai_rag/app/web/webui.py
@@ -0,0 +1,48 @@
+import gradio as gr
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.tabs.settings_tab import create_setting_tab
+from pai_rag.app.web.tabs.upload_tab import create_upload_tab
+from pai_rag.app.web.tabs.chat_tab import create_chat_tab
+from pai_rag.app.web.element_manager import elem_manager
+from pai_rag.app.web.ui_constants import (
+    DEFAULT_CSS_STYLE,
+ WELCOME_MESSAGE,
+)
+
+import logging
+
+logger = logging.getLogger("WebUILogger")
+
+
+def resume_ui():
+ outputs = {}
+ component_settings = view_model.to_component_settings()
+
+ for elem in elem_manager.get_elem_list():
+ elem_id = elem.elem_id
+ elem_attr = component_settings[elem_id]
+ elem = elem_manager.get_elem_by_id(elem_id=elem_id)
+
+        # With gradio 3.41.0 we must read .value from the re-instantiated component; on newer gradio versions the .value access can be dropped.
+ outputs[elem] = elem.__class__(**elem_attr).value
+
+ return outputs
+
+
+def create_ui():
+    with gr.Blocks(css=DEFAULT_CSS_STYLE) as homepage:
+ # generate components
+ gr.Markdown(value=WELCOME_MESSAGE)
+ with gr.Tab("\N{rocket} Settings"):
+ setting_elements = create_setting_tab()
+ elem_manager.add_elems(setting_elements)
+ with gr.Tab("\N{whale} Upload"):
+ upload_elements = create_upload_tab()
+ elem_manager.add_elems(upload_elements)
+ with gr.Tab("\N{fire} Chat"):
+ chat_elements = create_chat_tab()
+ elem_manager.add_elems(chat_elements)
+ homepage.load(
+ resume_ui, outputs=elem_manager.get_elem_list(), concurrency_limit=None
+ )
+ return homepage
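For a quick smoke test outside FastAPI (app.py mounts the page under
UI_PATH in production), the Blocks app can be launched directly; a sketch,
assuming no backend is needed until a button that calls rag_client is
clicked:

    from pai_rag.app.web.webui import create_ui

    create_ui().launch()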