From 1705a9f33d77eaf304fa021319988bdec204a4b9 Mon Sep 17 00:00:00 2001
From: Yue Fei <59813791+moria97@users.noreply.github.com>
Date: Tue, 4 Jun 2024 20:50:03 +0800
Subject: [PATCH] feat: enable Gradio page to share state between tabs (#48)

* Remove entrypoint

* Update ui

* Revert dockerfile change

* Delete debug clauses

* Fix embed dim value

* Update view model
---
 src/pai_rag/app/app.py                    |   3 +-
 src/pai_rag/app/web/element_manager.py    |  48 ++
 src/pai_rag/app/web/prompts.py            |  11 -
 src/pai_rag/app/web/tabs/chat_tab.py      | 249 ++++++++
 src/pai_rag/app/web/tabs/settings_tab.py  | 175 ++++++
 src/pai_rag/app/web/tabs/upload_tab.py    |  84 +++
 .../app/web/{ => tabs}/vector_db_panel.py | 110 ++--
 src/pai_rag/app/web/ui.py                 | 549 ------------------
 src/pai_rag/app/web/ui_constants.py       |  55 ++
 src/pai_rag/app/web/utils.py              |   5 +
 src/pai_rag/app/web/view_model.py         |  81 ++-
 src/pai_rag/app/web/webui.py              |  48 ++
 12 files changed, 803 insertions(+), 615 deletions(-)
 create mode 100644 src/pai_rag/app/web/element_manager.py
 delete mode 100644 src/pai_rag/app/web/prompts.py
 create mode 100644 src/pai_rag/app/web/tabs/chat_tab.py
 create mode 100644 src/pai_rag/app/web/tabs/settings_tab.py
 create mode 100644 src/pai_rag/app/web/tabs/upload_tab.py
 rename src/pai_rag/app/web/{ => tabs}/vector_db_panel.py (79%)
 delete mode 100644 src/pai_rag/app/web/ui.py
 create mode 100644 src/pai_rag/app/web/ui_constants.py
 create mode 100644 src/pai_rag/app/web/utils.py
 create mode 100644 src/pai_rag/app/web/webui.py

diff --git a/src/pai_rag/app/app.py b/src/pai_rag/app/app.py
index e4c418ea..4db4797e 100644
--- a/src/pai_rag/app/app.py
+++ b/src/pai_rag/app/app.py
@@ -3,10 +3,9 @@
 from pai_rag.core.rag_service import rag_service
 from pai_rag.app.api import query
 from pai_rag.app.api.middleware import init_middleware
-from pai_rag.app.web.ui import create_ui
+from pai_rag.app.web.webui import create_ui
 from pai_rag.app.web.rag_client import rag_client
 
-# UI_PATH = "/ui"
 UI_PATH = ""
 
diff --git a/src/pai_rag/app/web/element_manager.py b/src/pai_rag/app/web/element_manager.py
new file mode 100644
index 00000000..d374bbb5
--- /dev/null
+++ b/src/pai_rag/app/web/element_manager.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING, Dict, Generator, List, Tuple
+
+if TYPE_CHECKING:
+    from gradio.components import Component
+
+
+class ElementManager:
+    def __init__(self) -> None:
+        self._id_to_elem: Dict[str, "Component"] = {}
+        self._elem_to_id: Dict["Component", str] = {}
+
+    def add_elems(self, elem_dict: Dict[str, "Component"]) -> None:
+        r"""
+        Adds elements to the manager.
+        """
+        for elem_id, elem in elem_dict.items():
+            self._id_to_elem[elem_id] = elem
+            self._elem_to_id[elem] = elem_id
+
+    def get_elem_list(self) -> List["Component"]:
+        r"""
+        Returns the list of all elements.
+        """
+        return list(self._id_to_elem.values())
+
+    def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
+        r"""
+        Returns an iterator over all elements with their names.
+        """
+        for elem_id, elem in self._id_to_elem.items():
+            yield elem_id.split(".")[-1], elem
+
+    def get_elem_by_id(self, elem_id: str) -> "Component":
+        r"""
+        Gets an element by its id.
+
+        Example: top.lang, train.dataset
+        """
+        return self._id_to_elem[elem_id]
+
+    def get_id_by_elem(self, elem: "Component") -> str:
+        r"""
+        Gets the id of an element.
+        """
+        return self._elem_to_id[elem]
+
+
+elem_manager = ElementManager()
diff --git a/src/pai_rag/app/web/prompts.py b/src/pai_rag/app/web/prompts.py
deleted file mode 100644
index 2b6be084..00000000
--- a/src/pai_rag/app/web/prompts.py
+++ /dev/null
@@ -1,11 +0,0 @@
-SIMPLE_PROMPTS = "参考内容如下:\n{context_str}\n作为个人知识答疑助手,请根据上述参考内容回答下面问题,答案中不允许包含编造内容。\n用户问题:\n{query_str}"
-GENERAL_PROMPTS = '基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context_str}\n=====\n用户问题:\n{query_str}'
-EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
-ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
-
-PROMPT_MAP = {
-    SIMPLE_PROMPTS: "Simple",
-    GENERAL_PROMPTS: "General",
-    EXTRACT_URL_PROMPTS: "Extract URL",
-    ACCURATE_CONTENT_PROMPTS: "Accurate Content",
-}
diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py
new file mode 100644
index 00000000..5ada0115
--- /dev/null
+++ b/src/pai_rag/app/web/tabs/chat_tab.py
@@ -0,0 +1,249 @@
+from typing import Dict, Any, List
+import gradio as gr
+from pai_rag.app.web.rag_client import rag_client
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.ui_constants import (
+    SIMPLE_PROMPTS,
+    GENERAL_PROMPTS,
+    EXTRACT_URL_PROMPTS,
+    ACCURATE_CONTENT_PROMPTS,
+)
+
+
+current_session_id = None
+
+
+def clear_history(chatbot):
+    chatbot = []
+    global current_session_id
+    current_session_id = None
+    return chatbot, 0
+
+
+def respond(input_elements: List[Any]):
+    global current_session_id
+
+    update_dict = {}
+    for element, value in input_elements.items():
+        update_dict[element.elem_id] = value
+
+    # empty input.
+    if not update_dict["question"]:
+        return "", update_dict["chatbot"], 0
+
+    view_model.update(update_dict)
+    new_config = view_model.to_app_config()
+    rag_client.reload_config(new_config)
+
+    query_type = update_dict["query_type"]
+    msg = update_dict["question"]
+    chatbot = update_dict["chatbot"]
+
+    if query_type == "LLM":
+        response = rag_client.query_llm(
+            msg,
+            session_id=current_session_id,
+        )
+
+    elif query_type == "Retrieval":
+        response = rag_client.query_vector(msg)
+    else:
+        response = rag_client.query(msg, session_id=current_session_id)
+        if update_dict["include_history"]:
+            current_session_id = response.session_id
+        else:
+            current_session_id = None
+    chatbot.append((msg, response.answer))
+    return "", chatbot, 0
+
+
+def create_chat_tab() -> Dict[str, Any]:
+    with gr.Row():
+        with gr.Column(scale=2):
+            query_type = gr.Radio(
+                ["Retrieval", "LLM", "RAG (Retrieval + LLM)"],
+                label="\N{fire} Which query do you want to use?",
+                elem_id="query_type",
+                value="RAG (Retrieval + LLM)",
+            )
+
+            with gr.Column(visible=True) as vs_col:
+                vec_model_argument = gr.Accordion("Parameters of Vector Retrieval")
+
+                with vec_model_argument:
+                    similarity_top_k = gr.Slider(
+                        minimum=0,
+                        maximum=100,
+                        step=1,
+                        elem_id="similarity_top_k",
+                        label="Top K (choose between 0 and 100)",
+                    )
+                    # similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)")
+                    rerank_model = gr.Radio(
+                        [
+                            "no-reranker",
+                            "bge-reranker-base",
+                            "bge-reranker-large",
+                            "llm-reranker",
+                        ],
+                        label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)",
+                        elem_id="rerank_model",
+                    )
+                    retrieval_mode = gr.Radio(
+                        ["Embedding Only", "Keyword Ensembled", "Keyword Only"],
+                        label="Retrieval Mode",
+                        elem_id="retrieval_mode",
+                    )
+                vec_args = {
+                    similarity_top_k,
+                    # similarity_cutoff,
+                    rerank_model,
+                    retrieval_mode,
+                }
+            with gr.Column(visible=True) as llm_col:
+                model_argument = gr.Accordion("Inference Parameters of LLM")
+                with model_argument:
+                    include_history = gr.Checkbox(
+                        label="Chat history",
+                        info="Query with chat history.",
+                        elem_id="include_history",
+                    )
+                    llm_topk = gr.Slider(
+                        minimum=0,
+                        maximum=100,
+                        step=1,
+                        value=30,
+                        elem_id="llm_topk",
+                        label="Top K (choose between 0 and 100)",
+                    )
+                    llm_topp = gr.Slider(
+                        minimum=0,
+                        maximum=1,
+                        step=0.01,
+                        value=0.8,
+                        elem_id="llm_topp",
+                        label="Top P (choose between 0 and 1)",
+                    )
+                    llm_temp = gr.Slider(
+                        minimum=0,
+                        maximum=1,
+                        step=0.01,
+                        value=0.7,
+                        elem_id="llm_temp",
+                        label="Temperature (choose between 0 and 1)",
+                    )
+                llm_args = {llm_topk, llm_topp, llm_temp, include_history}
+
+            with gr.Column(visible=True) as lc_col:
+                prm_type = gr.Radio(
+                    [
+                        "Simple",
+                        "General",
+                        "Extract URL",
+                        "Accurate Content",
+                        "Custom",
+                    ],
+                    label="\N{rocket} Please choose the prompt template type",
+                    elem_id="prm_type",
+                )
+                text_qa_template = gr.Textbox(
+                    label="prompt template",
+                    placeholder="",
+                    elem_id="text_qa_template",
+                    lines=4,
+                )
+
+                def change_prompt_template(prm_type):
+                    if prm_type == "Simple":
+                        return {
+                            text_qa_template: gr.update(
+                                value=SIMPLE_PROMPTS, interactive=False
+                            )
+                        }
+                    elif prm_type == "General":
+                        return {
+                            text_qa_template: gr.update(
+                                value=GENERAL_PROMPTS, interactive=False
+                            )
+                        }
+                    elif prm_type == "Extract URL":
+                        return {
+                            text_qa_template: gr.update(
+                                value=EXTRACT_URL_PROMPTS, interactive=False
+                            )
+                        }
+                    elif prm_type == "Accurate Content":
+                        return {
+                            text_qa_template: gr.update(
+                                value=ACCURATE_CONTENT_PROMPTS,
+                                interactive=False,
+                            )
+                        }
+                    else:
+                        return {text_qa_template: gr.update(value="", interactive=True)}
+
+                prm_type.change(
+                    fn=change_prompt_template,
+                    inputs=prm_type,
+                    outputs=[text_qa_template],
+                )
+
+            cur_tokens = gr.Textbox(
+                label="\N{fire} Current total count of tokens", visible=False
+            )
+
+            def change_query_radio(query_type):
+                global current_session_id
+                current_session_id = None
+                if query_type == "Retrieval":
+                    return {
+                        vs_col: gr.update(visible=True),
+                        llm_col: gr.update(visible=False),
+                        lc_col: gr.update(visible=False),
+                    }
+                elif query_type == "LLM":
+                    return {
+                        vs_col: gr.update(visible=False),
+                        llm_col: gr.update(visible=True),
+                        lc_col: gr.update(visible=False),
+                    }
+                elif query_type == "RAG (Retrieval + LLM)":
+                    return {
+                        vs_col: gr.update(visible=True),
+                        llm_col: gr.update(visible=True),
+                        lc_col: gr.update(visible=True),
+                    }
+
+            query_type.change(
+                fn=change_query_radio,
+                inputs=query_type,
+                outputs=[vs_col, llm_col, lc_col],
+            )
+
+        with gr.Column(scale=8):
+            chatbot = gr.Chatbot(height=500, elem_id="chatbot")
+            question = gr.Textbox(label="Enter your question.", elem_id="question")
+            with gr.Row():
+                submitBtn = gr.Button("Submit", variant="primary")
+                clearBtn = gr.Button("Clear History", variant="secondary")
+
+    chat_args = (
+        {text_qa_template, question, query_type, chatbot}
+        .union(vec_args)
+        .union(llm_args)
+    )
+
+    submitBtn.click(
+        respond,
+        chat_args,
+        [question, chatbot, cur_tokens],
+        api_name="respond",
+    )
+    clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens])
+    return {
+        similarity_top_k.elem_id: similarity_top_k,
+        rerank_model.elem_id: rerank_model,
+        retrieval_mode.elem_id: retrieval_mode,
+        prm_type.elem_id: prm_type,
+        text_qa_template.elem_id: text_qa_template,
+    }
diff --git a/src/pai_rag/app/web/tabs/settings_tab.py b/src/pai_rag/app/web/tabs/settings_tab.py
new file mode 100644
index 00000000..f7614903
--- /dev/null
+++ b/src/pai_rag/app/web/tabs/settings_tab.py
@@ -0,0 +1,175 @@
+from typing import Dict, Any, List
+import gradio as gr
+import datetime
+import traceback
+from pai_rag.app.web.ui_constants import (
+    EMBEDDING_API_KEY_DICT,
+    DEFAULT_EMBED_SIZE,
+    EMBEDDING_DIM_DICT,
+    LLM_MODEL_KEY_DICT,
+)
+from pai_rag.app.web.rag_client import rag_client
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.utils import components_to_dict
+from pai_rag.app.web.tabs.vector_db_panel import create_vector_db_panel
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def connect_vector_db(input_elements: List[Any]):
+    try:
+        update_dict = {}
+        for element, value in input_elements.items():
+            update_dict[element.elem_id] = value
+
+        view_model.update(update_dict)
+        new_config = view_model.to_app_config()
+        rag_client.reload_config(new_config)
+        return f"[{datetime.datetime.now()}] Connect vector db success!"
+    except Exception as ex:
+        logger.critical(f"[Critical] Connect failed. {traceback.format_exc()}")
+        return f"Connect failed. Please check: {ex}"
+
+
+def create_setting_tab() -> Dict[str, Any]:
+    components = []
+    with gr.Row():
+        with gr.Column():
+            with gr.Column():
+                _ = gr.Markdown(value="**Please choose your embedding model.**")
+                embed_source = gr.Dropdown(
+                    EMBEDDING_API_KEY_DICT.keys(),
+                    label="Embedding Type",
+                    elem_id="embed_source",
+                )
+                embed_model = gr.Dropdown(
+                    EMBEDDING_DIM_DICT.keys(),
+                    label="Embedding Model Name",
+                    elem_id="embed_model",
+                    visible=False,
+                )
+                embed_dim = gr.Textbox(
+                    label="Embedding Dimension",
+                    elem_id="embed_dim",
+                )
+
+                def change_emb_source(source):
+                    view_model.embed_source = source
+                    return {
+                        embed_model: gr.update(visible=(source == "HuggingFace")),
+                        embed_dim: EMBEDDING_DIM_DICT.get(
+                            view_model.embed_model, DEFAULT_EMBED_SIZE
+                        )
+                        if source == "HuggingFace"
+                        else DEFAULT_EMBED_SIZE,
+                    }
+
+                def change_emb_model(model):
+                    view_model.embed_model = model
+                    return {
+                        embed_dim: EMBEDDING_DIM_DICT.get(
+                            view_model.embed_model, DEFAULT_EMBED_SIZE
+                        )
+                        if view_model.embed_source == "HuggingFace"
+                        else DEFAULT_EMBED_SIZE,
+                    }
+
+                embed_source.change(
+                    fn=change_emb_source,
+                    inputs=embed_source,
+                    outputs=[embed_model, embed_dim],
+                )
+                embed_model.change(
+                    fn=change_emb_model,
+                    inputs=embed_model,
+                    outputs=[embed_dim],
+                )
+                components.extend([embed_source, embed_dim, embed_model])
+
+            with gr.Column():
+                _ = gr.Markdown(value="**Please set your LLM.**")
+                llm = gr.Dropdown(
+                    ["PaiEas", "OpenAI", "DashScope"],
+                    label="LLM Model Source",
+                    elem_id="llm",
+                )
+                with gr.Column(visible=(view_model.llm == "PaiEas")) as eas_col:
+                    llm_eas_url = gr.Textbox(
+                        label="EAS Url",
+                        elem_id="llm_eas_url",
+                    )
+                    llm_eas_token = gr.Textbox(
+                        label="EAS Token",
+                        elem_id="llm_eas_token",
+                    )
+                    llm_eas_model_name = gr.Textbox(
+                        label="EAS Model name",
+                        elem_id="llm_eas_model_name",
+                    )
+                with gr.Column(visible=(view_model.llm != "PaiEas")) as api_llm_col:
+                    llm_api_model_name = gr.Dropdown(
+                        label="LLM Model Name",
+                        elem_id="llm_api_model_name",
+                    )
+
+                components.extend(
+                    [
+                        llm,
+                        llm_eas_url,
+                        llm_eas_token,
+                        llm_eas_model_name,
+                        llm_api_model_name,
+                    ]
+                )
+
+                def change_llm(value):
+                    view_model.llm = value
+                    eas_visible = value == "PaiEas"
+                    api_visible = value != "PaiEas"
+                    model_options = LLM_MODEL_KEY_DICT.get(value, [])
+                    cur_model = model_options[0] if model_options else ""
+                    return {
+                        eas_col: gr.update(visible=eas_visible),
+                        api_llm_col: gr.update(visible=api_visible),
+                        llm_api_model_name: gr.update(
+                            choices=model_options, value=cur_model
+                        ),
+                    }
+
+                llm.change(
+                    fn=change_llm,
+                    inputs=llm,
+                    outputs=[eas_col, api_llm_col, llm_api_model_name],
+                )
+                """
+                with gr.Column():
+                    _ = gr.Markdown(
+                        value="**(Optional) Please upload your config file.**"
+                    )
+                    config_file = gr.File(
+                        value=view_model.config_file,
+                        label="Upload a local config json file",
+                        file_types=[".json"],
+                        file_count="single",
+                        interactive=True,
+                    )
+                    cfg_btn = gr.Button("Parse Config", variant="primary")
+                """
+    vector_db_elems = create_vector_db_panel(
+        input_elements={
+            llm,
+            llm_eas_url,
+            llm_eas_token,
+            llm_eas_model_name,
+            embed_source,
+            embed_model,
+            embed_dim,
+            llm_api_model_name,
+        },
+        connect_vector_func=connect_vector_db,
+    )
+
+    elems = components_to_dict(components)
+    elems.update(vector_db_elems)
+    return elems
diff --git a/src/pai_rag/app/web/tabs/upload_tab.py b/src/pai_rag/app/web/tabs/upload_tab.py
new file mode 100644
index 00000000..8dba2c6b
--- /dev/null
+++ b/src/pai_rag/app/web/tabs/upload_tab.py
@@ -0,0 +1,84 @@
+import os
+from typing import Dict, Any
+import gradio as gr
+from pai_rag.app.web.rag_client import rag_client
+from pai_rag.app.web.view_model import view_model
+
+
+def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
+    view_model.chunk_size = chunk_size
+    view_model.chunk_overlap = chunk_overlap
+    new_config = view_model.to_app_config()
+    rag_client.reload_config(new_config)
+
+    if not upload_files:
+        return "No file selected. Please choose at least one file."
+
+    for file in upload_files:
+        file_dir = os.path.dirname(file.name)
+        rag_client.add_knowledge(file_dir, enable_qa_extraction)
+    return (
+        "Uploaded "
+        + str(len(upload_files))
+        + " files successfully! \n \n Relevant content has been added to the vector store, you can now start chatting and asking questions."
+    )
+
+
+def create_upload_tab() -> Dict[str, Any]:
+    with gr.Row():
+        with gr.Column(scale=2):
+            chunk_size = gr.Textbox(
+                label="\N{rocket} Chunk Size (The size of the chunks into which a document is divided)",
+                elem_id="chunk_size",
+            )
+
+            chunk_overlap = gr.Textbox(
+                label="\N{fire} Chunk Overlap (The portion of adjacent document chunks that overlap with each other)",
+                elem_id="chunk_overlap",
+            )
+            enable_qa_extraction = gr.Checkbox(
+                label="Yes",
+                info="Process with QA Extraction Model",
+                elem_id="enable_qa_extraction",
+            )
+        with gr.Column(scale=8):
+            with gr.Tab("Files"):
+                upload_file = gr.File(
+                    label="Upload a knowledge file.", file_count="multiple"
+                )
+                upload_file_btn = gr.Button("Upload", variant="primary")
+                upload_file_state = gr.Textbox(label="Upload State")
+            with gr.Tab("Directory"):
+                upload_file_dir = gr.File(
+                    label="Upload a knowledge directory.",
+                    file_count="directory",
+                )
+                upload_dir_btn = gr.Button("Upload", variant="primary")
+                upload_dir_state = gr.Textbox(label="Upload State")
+            upload_file_btn.click(
+                fn=upload_knowledge,
+                inputs=[
+                    upload_file,
+                    chunk_size,
+                    chunk_overlap,
+                    enable_qa_extraction,
+                ],
+                outputs=upload_file_state,
+                api_name="upload_knowledge",
+            )
+            upload_dir_btn.click(
+                fn=upload_knowledge,
+                inputs=[
+                    upload_file_dir,
+                    chunk_size,
+                    chunk_overlap,
+                    enable_qa_extraction,
+                ],
+                outputs=upload_dir_state,
+                api_name="upload_knowledge_dir",
+            )
+    return {
+        chunk_size.elem_id: chunk_size,
+        chunk_overlap.elem_id: chunk_overlap,
+        enable_qa_extraction.elem_id: enable_qa_extraction,
+    }
diff --git a/src/pai_rag/app/web/vector_db_panel.py b/src/pai_rag/app/web/tabs/vector_db_panel.py
similarity index 79%
rename from src/pai_rag/app/web/vector_db_panel.py
rename to src/pai_rag/app/web/tabs/vector_db_panel.py
index 8b7db86e..fc66df8a 100644
--- a/src/pai_rag/app/web/vector_db_panel.py
+++ b/src/pai_rag/app/web/tabs/vector_db_panel.py
@@ -1,14 +1,15 @@
 import gradio as gr
-from typing import Any, Set, Callable
-from pai_rag.app.web.view_model import ViewModel
+from typing import Any, Set, Callable, Dict
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.utils import components_to_dict
 
 
 def create_vector_db_panel(
-    view_model: ViewModel,
     input_elements: Set[Any],
     connect_vector_func: Callable[[Any], str],
-):
-    with gr.Column() as panel:
+) -> Dict[str, Any]:
+    components = []
+    with gr.Column():
         with gr.Column():
             _ = gr.Markdown(
                 value=f"**Please check your Vector Store for {view_model.vectordb_type}.**"
@@ -16,7 +17,6 @@ def create_vector_db_panel(
             vectordb_type = gr.Dropdown(
                 ["Hologres", "Milvus", "ElasticSearch", "AnalyticDB", "FAISS"],
                 label="Which VectorStore do you want to use?",
-                value=view_model.vectordb_type,
elem_id="vectordb_type", ) # Adb @@ -26,13 +26,11 @@ def create_vector_db_panel( adb_ak = gr.Textbox( label="access-key-id", type="password", - value=view_model.adb_ak, elem_id="adb_ak", ) adb_sk = gr.Textbox( label="access-key-secret", type="password", - value=view_model.adb_sk, elem_id="adb_sk", ) adb_region_id = gr.Dropdown( @@ -42,35 +40,28 @@ def create_vector_db_panel( "cn-zhangjiakou", "cn-huhehaote", "cn-shanghai", - " cn-shenzhen", + "cn-shenzhen", "cn-chengdu", ], label="RegionId", - value=view_model.adb_region_id, elem_id="adb_region_id", ) adb_instance_id = gr.Textbox( label="InstanceId", - value=view_model.adb_instance_id, elem_id="adb_instance_id", ) - adb_account = gr.Textbox( - label="Account", value=view_model.adb_account, elem_id="adb_account" - ) + adb_account = gr.Textbox(label="Account", elem_id="adb_account") adb_account_password = gr.Textbox( label="Password", type="password", - value=view_model.adb_account_password, elem_id="adb_account_password", ) adb_namespace = gr.Textbox( label="Namespace", - value=view_model.adb_namespace, elem_id="adb_namespace", ) adb_collection = gr.Textbox( label="CollectionName", - value=view_model.adb_collection, elem_id="adb_collection", ) @@ -101,36 +92,35 @@ def create_vector_db_panel( ) as holo_col: hologres_host = gr.Textbox( label="Host", - value=view_model.hologres_host, elem_id="hologres_host", ) - hologres_user = gr.Textbox( - label="User", - value=view_model.hologres_user, - elem_id="hologres_user", + hologres_port = gr.Textbox( + label="Port", + elem_id="hologres_port", ) hologres_database = gr.Textbox( label="Database", - value=view_model.hologres_database, elem_id="hologres_database", ) + hologres_user = gr.Textbox( + label="User", + elem_id="hologres_user", + ) hologres_password = gr.Textbox( label="Password", type="password", - value=view_model.hologres_password, elem_id="hologres_password", ) hologres_table = gr.Textbox( label="Table", - value=view_model.hologres_table, elem_id="hologres_table", ) hologres_pre_delete = gr.Dropdown( ["True", "False"], label="Pre Delete", - value=view_model.hologres_pre_delete, elem_id="hologres_pre_delete", ) + connect_btn_hologres = gr.Button("Connect Hologres", variant="primary") con_state_hologres = gr.Textbox(label="Connection Info: ") inputs_hologres = input_elements.union( @@ -154,21 +144,15 @@ def create_vector_db_panel( with gr.Column( visible=(view_model.vectordb_type == "ElasticSearch") ) as es_col: - es_url = gr.Textbox( - label="ElasticSearch Url", value=view_model.es_url, elem_id="es_url" - ) - es_index = gr.Textbox( - label="Index Name", value=view_model.es_index, elem_id="es_index" - ) - es_user = gr.Textbox( - label="ES User", value=view_model.es_user, elem_id="es_user" - ) + es_url = gr.Textbox(label="ElasticSearch Url", elem_id="es_url") + es_index = gr.Textbox(label="Index Name", elem_id="es_index") + es_user = gr.Textbox(label="ES User", elem_id="es_user") es_password = gr.Textbox( label="ES password", type="password", - value=view_model.es_password, elem_id="es_password", ) + inputs_es = input_elements.union( {vectordb_type, es_url, es_index, es_user, es_password} ) @@ -184,29 +168,21 @@ def create_vector_db_panel( with gr.Column( visible=(view_model.vectordb_type == "Milvus") ) as milvus_col: - milvus_host = gr.Textbox( - label="Host", value=view_model.milvus_host, elem_id="milvus_host" - ) - milvus_port = gr.Textbox( - label="Port", value=view_model.milvus_port, elem_id="milvus_port" - ) - milvus_user = gr.Textbox( - label="User", value=view_model.milvus_user, 
elem_id="milvus_user" - ) + milvus_host = gr.Textbox(label="Host", elem_id="milvus_host") + milvus_port = gr.Textbox(label="Port", elem_id="milvus_port") + + milvus_user = gr.Textbox(label="User", elem_id="milvus_user") milvus_password = gr.Textbox( label="Password", type="password", - value=view_model.milvus_password, elem_id="milvus_password", ) milvus_database = gr.Textbox( label="Database", - value=view_model.milvus_database, elem_id="milvus_database", ) milvus_collection_name = gr.Textbox( label="Collection name", - value=view_model.milvus_collection_name, elem_id="milvus_collection_name", ) @@ -231,9 +207,7 @@ def create_vector_db_panel( ) with gr.Column(visible=(view_model.vectordb_type == "FAISS")) as faiss_col: - faiss_path = gr.Textbox( - label="Path", value=view_model.faiss_path, elem_id="faiss_path" - ) + faiss_path = gr.Textbox(label="Path", elem_id="faiss_path") connect_btn_faiss = gr.Button("Connect Faiss", variant="primary") con_state_faiss = gr.Textbox(label="Connection Info: ") inputs_faiss = input_elements.union({vectordb_type, faiss_path}) @@ -275,4 +249,36 @@ def change_vectordb_conn(vectordb_type): outputs=[adb_col, holo_col, faiss_col, es_col, milvus_col], ) - return panel + components.extend( + [ + vectordb_type, + adb_ak, + adb_sk, + adb_region_id, + adb_instance_id, + adb_collection, + adb_account, + adb_account_password, + adb_namespace, + hologres_host, + hologres_port, + hologres_database, + hologres_user, + hologres_password, + hologres_table, + hologres_pre_delete, + milvus_host, + milvus_port, + milvus_database, + milvus_collection_name, + milvus_user, + milvus_password, + faiss_path, + es_url, + es_index, + es_user, + es_password, + ] + ) + + return components_to_dict(components) diff --git a/src/pai_rag/app/web/ui.py b/src/pai_rag/app/web/ui.py deleted file mode 100644 index b0e96d89..00000000 --- a/src/pai_rag/app/web/ui.py +++ /dev/null @@ -1,549 +0,0 @@ -import datetime -import gradio as gr -import os -from typing import List, Any -from pai_rag.app.web.view_model import view_model -from pai_rag.app.web.rag_client import rag_client -from pai_rag.app.web.vector_db_panel import create_vector_db_panel -from pai_rag.app.web.prompts import ( - SIMPLE_PROMPTS, - GENERAL_PROMPTS, - EXTRACT_URL_PROMPTS, - ACCURATE_CONTENT_PROMPTS, - PROMPT_MAP, -) - -import logging -import traceback - -logger = logging.getLogger("WebUILogger") - -welcome_message_markdown = """ - # \N{fire} Chatbot with RAG on PAI ! - ### \N{rocket} Build your own personalized knowledge base question-answering chatbot. 
-
-    #### \N{fire} Platform: [PAI](https://help.aliyun.com/zh/pai) / [PAI-EAS](https://www.aliyun.com/product/bigdata/learn/eas) / [PAI-DSW](https://pai.console.aliyun.com/notebook)   \N{rocket} Supported VectorStores: [Hologres](https://www.aliyun.com/product/bigdata/hologram) / [ElasticSearch](https://www.aliyun.com/product/bigdata/elasticsearch) / [AnalyticDB](https://www.aliyun.com/product/apsaradb/gpdb) / [FAISS](https://python.langchain.com/docs/integrations/vectorstores/faiss)
-
-    #### \N{fire} API Docs   \N{rocket} \N{fire} 欢迎加入【PAI】RAG答疑群 27370042974
-    """
-
-css_style = """
-    h1, h3, h4 {
-        text-align: center;
-        display:block;
-    }
-    """
-
-DEFAULT_EMBED_SIZE = 1536
-DEFAULT_HF_EMBED_MODEL = "bge-small-zh-v1.5"
-
-embedding_dim_dict = {
-    "bge-small-zh-v1.5": 1024,
-    "SGPT-125M-weightedmean-nli-bitfit": 768,
-    "text2vec-large-chinese": 1024,
-    "text2vec-base-chinese": 768,
-    "paraphrase-multilingual-MiniLM-L12-v2": 384,
-}
-
-embedding_api_key_dict = {"HuggingFace": False, "OpenAI": True, "DashScope": True}
-
-llm_model_key_dict = {
-    "DashScope": [
-        "qwen-turbo",
-        "qwen-plus",
-        "qwen-max",
-        "qwen-max-1201",
-        "qwen-max-longcontext",
-    ],
-    "OpenAI": [
-        "gpt-3.5-turbo",
-        "gpt-4-turbo",
-    ],
-}
-
-current_session_id = None
-
-
-def connect_vector_db(input_elements: List[Any]):
-    try:
-        update_dict = {}
-        for element, value in input_elements.items():
-            update_dict[element.elem_id] = value
-
-        view_model.update(update_dict)
-        new_config = view_model.to_app_config()
-        rag_client.reload_config(new_config)
-        return f"[{datetime.datetime.now()}] Connect vector db success!"
-    except Exception as ex:
-        logger.critical(f"[Critical] Connect failed. {traceback.format_exc()}")
-        return f"Connect failed. Please check: {ex}"
-
-
-def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
-    view_model.chunk_size = chunk_size
-    view_model.chunk_overlap = chunk_overlap
-    new_config = view_model.to_app_config()
-    rag_client.reload_config(new_config)
-
-    if not upload_files:
-        return "No file selected. Please choose at least one file."
-
-    for file in upload_files:
-        file_dir = os.path.dirname(file.name)
-        rag_client.add_knowledge(file_dir, enable_qa_extraction)
-    return (
-        "Upload "
-        + str(len(upload_files))
-        + " files Success! \n \n Relevant content has been added to the vector store, you can now start chatting and asking questions."
-    )
-
-
-def clear_history(chatbot):
-    chatbot = []
-    global current_session_id
-    current_session_id = None
-    return chatbot, 0
-
-
-def respond(input_elements: List[Any]):
-    global current_session_id
-
-    update_dict = {}
-    for element, value in input_elements.items():
-        update_dict[element.elem_id] = value
-
-    # empty input.
-    if not update_dict["question"]:
-        return "", update_dict["chatbot"], 0
-
-    view_model.update(update_dict)
-    new_config = view_model.to_app_config()
-    rag_client.reload_config(new_config)
-
-    query_type = update_dict["query_type"]
-    msg = update_dict["question"]
-    chatbot = update_dict["chatbot"]
-
-    if query_type == "LLM":
-        response = rag_client.query_llm(
-            msg,
-            session_id=current_session_id,
-        )
-
-    elif query_type == "Retrieval":
-        response = rag_client.query_vector(msg)
-    else:
-        response = rag_client.query(msg, session_id=current_session_id)
-        if update_dict["include_history"]:
-            current_session_id = response.session_id
-        else:
-            current_session_id = None
-    chatbot.append((msg, response.answer))
-    return "", chatbot, 0
-
-
-def create_ui():
-    with gr.Blocks(css=css_style) as homepage:
-        gr.Markdown(value=welcome_message_markdown)
-
-        with gr.Tab("\N{rocket} Settings"):
-            with gr.Row():
-                with gr.Column():
-                    with gr.Column():
-                        _ = gr.Markdown(value="**Please choose your embedding model.**")
-                        embed_source = gr.Dropdown(
-                            embedding_api_key_dict.keys(),
-                            label="Embedding Type",
-                            value=view_model.embed_source,
-                            elem_id="embed_source",
-                        )
-                        embed_model = gr.Dropdown(
-                            embedding_dim_dict.keys(),
-                            label="Embedding Model Name",
-                            value=view_model.embed_model,
-                            elem_id="embed_model",
-                            visible=(view_model.embed_source == "HuggingFace"),
-                        )
-                        embed_dim = gr.Textbox(
-                            label="Embedding Dimension",
-                            value=embedding_dim_dict.get(
-                                view_model.embed_model, DEFAULT_EMBED_SIZE
-                            ),
-                            elem_id="embed_dim",
-                        )
-
-                        def change_emb_source(source):
-                            view_model.embed_source = source
-                            view_model.embed_model = (
-                                DEFAULT_HF_EMBED_MODEL
-                                if source == "HuggingFace"
-                                else source
-                            )
-                            _embed_dim = (
-                                embedding_dim_dict.get(
-                                    view_model.embed_model, DEFAULT_EMBED_SIZE
-                                )
-                                if source == "HuggingFace"
-                                else DEFAULT_EMBED_SIZE
-                            )
-                            return {
-                                embed_model: gr.update(
-                                    visible=(source == "HuggingFace"),
-                                    value=view_model.embed_model,
-                                ),
-                                embed_dim: _embed_dim,
-                            }
-
-                        def change_emb_model(model):
-                            view_model.embed_model = model
-                            return {
-                                embed_dim: embedding_dim_dict.get(
-                                    view_model.embed_model, DEFAULT_EMBED_SIZE
-                                ),
-                            }
-
-                        embed_source.change(
-                            fn=change_emb_source,
-                            inputs=embed_source,
-                            outputs=[embed_model, embed_dim],
-                        )
-                        embed_model.change(
-                            fn=change_emb_model,
-                            inputs=embed_model,
-                            outputs=[embed_dim],
-                        )
-
-                    with gr.Column():
-                        _ = gr.Markdown(value="**Please set your LLM.**")
-                        llm_src = gr.Dropdown(
-                            ["PaiEas", "OpenAI", "DashScope"],
-                            label="LLM Model Source",
-                            value=view_model.llm,
-                            elem_id="llm",
-                        )
-                        with gr.Column(visible=(view_model.llm == "PaiEas")) as eas_col:
-                            llm_eas_url = gr.Textbox(
-                                label="EAS Url",
-                                value=view_model.llm_eas_url,
-                                elem_id="llm_eas_url",
-                            )
-                            llm_eas_token = gr.Textbox(
-                                label="EAS Token",
-                                value=view_model.llm_eas_token,
-                                elem_id="llm_eas_token",
-                            )
-                            llm_eas_model_name = gr.Textbox(
-                                label="EAS Model name",
-                                value=view_model.llm_eas_model_name,
-                                elem_id="llm_eas_model_name",
-                            )
-                        with gr.Column(
-                            visible=(view_model.llm != "PaiEas")
-                        ) as api_llm_col:
-                            llm_api_model_name = gr.Dropdown(
-                                llm_model_key_dict.get(view_model.llm, []),
-                                label="LLM Model Name",
-                                value=view_model.llm_api_model_name,
-                                elem_id="llm_api_model_name",
-                            )
-
-                        def change_llm_src(value):
-                            view_model.llm = value
-                            eas_visible = value == "PaiEas"
-                            api_visible = value != "PaiEas"
-                            model_options = llm_model_key_dict.get(value, [])
-                            cur_model = model_options[0] if model_options else ""
-                            return {
-                                eas_col: gr.update(visible=eas_visible),
-                                api_llm_col: gr.update(visible=api_visible),
-                                llm_api_model_name: gr.update(
-                                    choices=model_options, value=cur_model
-                                ),
-                            }
-
-                        llm_src.change(
-                            fn=change_llm_src,
-                            inputs=llm_src,
-                            outputs=[eas_col, api_llm_col, llm_api_model_name],
-                        )
-                        """
-                        with gr.Column():
-                            _ = gr.Markdown(
-                                value="**(Optional) Please upload your config file.**"
-                            )
-                            config_file = gr.File(
-                                value=view_model.config_file,
-                                label="Upload a local config json file",
-                                file_types=[".json"],
-                                file_count="single",
-                                interactive=True,
-                            )
-                            cfg_btn = gr.Button("Parse Config", variant="primary")
-                        """
-                        create_vector_db_panel(
-                            view_model=view_model,
-                            input_elements={
-                                llm_src,
-                                llm_eas_url,
-                                llm_eas_token,
-                                llm_eas_model_name,
-                                embed_source,
-                                embed_model,
-                                embed_dim,
-                                llm_api_model_name,
-                            },
-                            connect_vector_func=connect_vector_db,
-                        )
-
-        with gr.Tab("\N{whale} Upload"):
-            with gr.Row():
-                with gr.Column(scale=2):
-                    chunk_size = gr.Textbox(
-                        label="\N{rocket} Chunk Size (The size of the chunks into which a document is divided)",
-                        value=view_model.chunk_size,
-                    )
-                    chunk_overlap = gr.Textbox(
-                        label="\N{fire} Chunk Overlap (The portion of adjacent document chunks that overlap with each other)",
-                        value=view_model.chunk_overlap,
-                    )
-                    enable_qa_extraction = gr.Checkbox(
-                        label="Yes",
-                        info="Process with QA Extraction Model",
-                        value=view_model.enable_qa_extraction,
-                        elem_id="enable_qa_extraction",
-                    )
-                with gr.Column(scale=8):
-                    with gr.Tab("Files"):
-                        upload_file = gr.File(
-                            label="Upload a knowledge file.", file_count="multiple"
-                        )
-                        upload_file_btn = gr.Button("Upload", variant="primary")
-                        upload_file_state = gr.Textbox(label="Upload State")
-                    with gr.Tab("Directory"):
-                        upload_file_dir = gr.File(
-                            label="Upload a knowledge directory.",
-                            file_count="directory",
-                        )
-                        upload_dir_btn = gr.Button("Upload", variant="primary")
-                        upload_dir_state = gr.Textbox(label="Upload State")
-                    upload_file_btn.click(
-                        fn=upload_knowledge,
-                        inputs=[
-                            upload_file,
-                            chunk_size,
-                            chunk_overlap,
-                            enable_qa_extraction,
-                        ],
-                        outputs=upload_file_state,
-                        api_name="upload_knowledge",
-                    )
-                    upload_dir_btn.click(
-                        fn=upload_knowledge,
-                        inputs=[
-                            upload_file_dir,
-                            chunk_size,
-                            chunk_overlap,
-                            enable_qa_extraction,
-                        ],
-                        outputs=upload_dir_state,
-                        api_name="upload_knowledge_dir",
-                    )
-
-        with gr.Tab("\N{fire} Chat"):
-            with gr.Row():
-                with gr.Column(scale=2):
-                    query_type = gr.Radio(
-                        ["Retrieval", "LLM", "RAG (Retrieval + LLM)"],
-                        label="\N{fire} Which query do you want to use?",
-                        elem_id="query_type",
-                        value="RAG (Retrieval + LLM)",
-                    )
-
-                    with gr.Column(visible=True) as vs_col:
-                        vec_model_argument = gr.Accordion(
-                            "Parameters of Vector Retrieval"
-                        )
-
-                        with vec_model_argument:
-                            similarity_top_k = gr.Slider(
-                                minimum=0,
-                                maximum=100,
-                                step=1,
-                                elem_id="similarity_top_k",
-                                value=view_model.similarity_top_k,
-                                label="Top K (choose between 0 and 100)",
-                            )
-                            # similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)")
-                            rerank_model = gr.Radio(
-                                [
-                                    "no-reranker",
-                                    "bge-reranker-base",
-                                    "bge-reranker-large",
-                                    "llm-reranker",
-                                ],
-                                label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)",
-                                elem_id="rerank_model",
-                                value=view_model.rerank_model,
-                            )
-                            retrieval_mode = gr.Radio(
-                                ["Embedding Only", "Keyword Ensembled", "Keyword Only"],
label="Retrieval Mode", - elem_id="retrieval_mode", - value=view_model.retrieval_mode, - ) - vec_args = { - similarity_top_k, - # similarity_cutoff, - rerank_model, - retrieval_mode, - } - with gr.Column(visible=True) as llm_col: - model_argument = gr.Accordion("Inference Parameters of LLM") - with model_argument: - include_history = gr.Checkbox( - label="Chat history", - info="Query with chat history.", - elem_id="include_history", - ) - llm_topk = gr.Slider( - minimum=0, - maximum=100, - step=1, - value=30, - elem_id="llm_topk", - label="Top K (choose between 0 and 100)", - ) - llm_topp = gr.Slider( - minimum=0, - maximum=1, - step=0.01, - value=0.8, - elem_id="llm_topp", - label="Top P (choose between 0 and 1)", - ) - llm_temp = gr.Slider( - minimum=0, - maximum=1, - step=0.01, - value=0.7, - elem_id="llm_temp", - label="Temperature (choose between 0 and 1)", - ) - llm_args = {llm_topk, llm_topp, llm_temp, include_history} - - with gr.Column(visible=True) as lc_col: - prm_type = PROMPT_MAP.get(view_model.text_qa_template, "Custom") - prm_radio = gr.Radio( - [ - "Simple", - "General", - "Extract URL", - "Accurate Content", - "Custom", - ], - label="\N{rocket} Please choose the prompt template type", - value=prm_type, - ) - text_qa_template = gr.Textbox( - label="prompt template", - placeholder=view_model.text_qa_template, - value=view_model.text_qa_template, - elem_id="text_qa_template", - lines=4, - ) - - def change_prompt_template(prm_radio): - if prm_radio == "Simple": - return { - text_qa_template: gr.update( - value=SIMPLE_PROMPTS, interactive=False - ) - } - elif prm_radio == "General": - return { - text_qa_template: gr.update( - value=GENERAL_PROMPTS, interactive=False - ) - } - elif prm_radio == "Extract URL": - return { - text_qa_template: gr.update( - value=EXTRACT_URL_PROMPTS, interactive=False - ) - } - elif prm_radio == "Accurate Content": - return { - text_qa_template: gr.update( - value=ACCURATE_CONTENT_PROMPTS, - interactive=False, - ) - } - else: - return { - text_qa_template: gr.update( - value="", interactive=True - ) - } - - prm_radio.change( - fn=change_prompt_template, - inputs=prm_radio, - outputs=[text_qa_template], - ) - - cur_tokens = gr.Textbox( - label="\N{fire} Current total count of tokens", visible=False - ) - - def change_query_radio(query_type): - global current_session_id - current_session_id = None - if query_type == "Retrieval": - return { - vs_col: gr.update(visible=True), - llm_col: gr.update(visible=False), - lc_col: gr.update(visible=False), - } - elif query_type == "LLM": - return { - vs_col: gr.update(visible=False), - llm_col: gr.update(visible=True), - lc_col: gr.update(visible=False), - } - elif query_type == "RAG (Retrieval + LLM)": - return { - vs_col: gr.update(visible=True), - llm_col: gr.update(visible=True), - lc_col: gr.update(visible=True), - } - - query_type.change( - fn=change_query_radio, - inputs=query_type, - outputs=[vs_col, llm_col, lc_col], - ) - - with gr.Column(scale=8): - chatbot = gr.Chatbot(height=500, elem_id="chatbot") - question = gr.Textbox( - label="Enter your question.", elem_id="question" - ) - with gr.Row(): - submitBtn = gr.Button("Submit", variant="primary") - clearBtn = gr.Button("Clear History", variant="secondary") - - chat_args = ( - {text_qa_template, question, query_type, chatbot} - .union(vec_args) - .union(llm_args) - ) - - submitBtn.click( - respond, - chat_args, - [question, chatbot, cur_tokens], - api_name="respond", - ) - clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens]) - - return 
-    return homepage
diff --git a/src/pai_rag/app/web/ui_constants.py b/src/pai_rag/app/web/ui_constants.py
new file mode 100644
index 00000000..69946f6e
--- /dev/null
+++ b/src/pai_rag/app/web/ui_constants.py
@@ -0,0 +1,55 @@
+SIMPLE_PROMPTS = "参考内容如下:\n{context_str}\n作为个人知识答疑助手,请根据上述参考内容回答下面问题,答案中不允许包含编造内容。\n用户问题:\n{query_str}"
+GENERAL_PROMPTS = '基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n=====\n已知信息:\n{context_str}\n=====\n用户问题:\n{query_str}'
+EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
+ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
+
+PROMPT_MAP = {
+    SIMPLE_PROMPTS: "Simple",
+    GENERAL_PROMPTS: "General",
+    EXTRACT_URL_PROMPTS: "Extract URL",
+    ACCURATE_CONTENT_PROMPTS: "Accurate Content",
+}
+
+WELCOME_MESSAGE = """
+    # \N{fire} Chatbot with RAG on PAI !
+    ### \N{rocket} Build your own personalized knowledge base question-answering chatbot.
+
+    #### \N{fire} Platform: [PAI](https://help.aliyun.com/zh/pai) / [PAI-EAS](https://www.aliyun.com/product/bigdata/learn/eas) / [PAI-DSW](https://pai.console.aliyun.com/notebook)   \N{rocket} Supported VectorStores: [Hologres](https://www.aliyun.com/product/bigdata/hologram) / [ElasticSearch](https://www.aliyun.com/product/bigdata/elasticsearch) / [AnalyticDB](https://www.aliyun.com/product/apsaradb/gpdb) / [FAISS](https://python.langchain.com/docs/integrations/vectorstores/faiss)
+
+    #### \N{fire} API Docs   \N{rocket} \N{fire} 欢迎加入【PAI】RAG答疑群 27370042974
+    """
+
+DEFAULT_CSS_STYLE = """
+    h1, h3, h4 {
+        text-align: center;
+        display:block;
+    }
+    """
+
+DEFAULT_EMBED_SIZE = 1536
+
+EMBEDDING_DIM_DICT = {
+    "bge-small-zh-v1.5": 1024,
+    "SGPT-125M-weightedmean-nli-bitfit": 768,
+    "text2vec-large-chinese": 1024,
+    "text2vec-base-chinese": 768,
+    "paraphrase-multilingual-MiniLM-L12-v2": 384,
+}
+
+DEFAULT_HF_EMBED_MODEL = "bge-small-zh-v1.5"
+
+EMBEDDING_API_KEY_DICT = {"HuggingFace": False, "OpenAI": True, "DashScope": True}
+
+LLM_MODEL_KEY_DICT = {
+    "DashScope": [
+        "qwen-turbo",
+        "qwen-plus",
+        "qwen-max",
+        "qwen-max-1201",
+        "qwen-max-longcontext",
+    ],
+    "OpenAI": [
+        "gpt-3.5-turbo",
+        "gpt-4-turbo",
+    ],
+}
diff --git a/src/pai_rag/app/web/utils.py b/src/pai_rag/app/web/utils.py
new file mode 100644
index 00000000..86bbe30c
--- /dev/null
+++ b/src/pai_rag/app/web/utils.py
@@ -0,0 +1,5 @@
+from typing import List, Any, Dict
+
+
+def components_to_dict(components: List[Any]) -> Dict[str, Any]:
+    return {c.elem_id: c for c in components}
diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index d095e304..598a1b1b 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -1,6 +1,13 @@
 from pydantic import BaseModel
 from typing import Any, Dict
 from collections import defaultdict
+from pai_rag.app.web.ui_constants import (
+    EMBEDDING_DIM_DICT,
+    DEFAULT_EMBED_SIZE,
+    DEFAULT_HF_EMBED_MODEL,
+    LLM_MODEL_KEY_DICT,
+    PROMPT_MAP,
+)
 
 
 def recursive_dict():
@@ -17,7 +24,7 @@ def _transform_to_dict(config):
 class ViewModel(BaseModel):
     # embedding
     embed_source: str = "HuggingFace"
-    embed_model: str = "bge-small-zh-v1.5"
+    embed_model: str = DEFAULT_HF_EMBED_MODEL
     embed_dim: int = 1024
     embed_api_key: str = None
 
@@ -299,5 +306,77 @@ def to_app_config(self):
         return _transform_to_dict(config)
 
+    def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
+        settings = {}
+        settings["embed_source"] = {"value": self.embed_source}
+        settings["embed_model"] = {
+            "value": self.embed_model,
+            "visible": self.embed_source == "HuggingFace",
+        }
+        settings["embed_dim"] = {
+            "value": EMBEDDING_DIM_DICT.get(self.embed_model, DEFAULT_EMBED_SIZE)
+            if self.embed_source == "HuggingFace"
+            else DEFAULT_EMBED_SIZE
+        }
+
+        settings["llm"] = {"value": self.llm}
+        settings["llm_eas_url"] = {"value": self.llm_eas_url}
+        settings["llm_eas_token"] = {"value": self.llm_eas_token}
+        settings["llm_eas_model_name"] = {"value": self.llm_eas_model_name}
+        settings["llm_api_model_name"] = {
+            "value": self.llm_api_model_name,
+            "choices": LLM_MODEL_KEY_DICT.get(self.llm, []),
+        }
+        settings["chunk_size"] = {"value": self.chunk_size}
+        settings["chunk_overlap"] = {"value": self.chunk_overlap}
+        settings["enable_qa_extraction"] = {"value": self.enable_qa_extraction}
+        settings["similarity_top_k"] = {"value": self.similarity_top_k}
+        settings["rerank_model"] = {"value": self.rerank_model}
+        settings["retrieval_mode"] = {"value": self.retrieval_mode}
+
+        prm_type = PROMPT_MAP.get(self.text_qa_template, "Custom")
+        settings["prm_type"] = {"value": prm_type}
+        settings["text_qa_template"] = {"value": self.text_qa_template}
+
+        settings["vectordb_type"] = {"value": self.vectordb_type}
+
+        # adb
+        settings["adb_ak"] = {"value": self.adb_ak}
+        settings["adb_sk"] = {"value": self.adb_sk}
+        settings["adb_region_id"] = {"value": self.adb_region_id}
+        settings["adb_account"] = {"value": self.adb_account}
+        settings["adb_account_password"] = {"value": self.adb_account_password}
+        settings["adb_namespace"] = {"value": self.adb_namespace}
+        settings["adb_instance_id"] = {"value": self.adb_instance_id}
+        settings["adb_collection"] = {"value": self.adb_collection}
+
+        # hologres
+        settings["hologres_host"] = {"value": self.hologres_host}
+        settings["hologres_database"] = {"value": self.hologres_database}
+        settings["hologres_port"] = {"value": self.hologres_port}
+        settings["hologres_user"] = {"value": self.hologres_user}
+        settings["hologres_password"] = {"value": self.hologres_password}
+        settings["hologres_table"] = {"value": self.hologres_table}
+        settings["hologres_pre_delete"] = {"value": self.hologres_pre_delete}
+
+        # elasticsearch
+        settings["es_url"] = {"value": self.es_url}
+        settings["es_index"] = {"value": self.es_index}
+        settings["es_user"] = {"value": self.es_user}
+        settings["es_password"] = {"value": self.es_password}
+
+        # milvus
+        settings["milvus_host"] = {"value": self.milvus_host}
+        settings["milvus_port"] = {"value": self.milvus_port}
+        settings["milvus_database"] = {"value": self.milvus_database}
+        settings["milvus_user"] = {"value": self.milvus_user}
+        settings["milvus_password"] = {"value": self.milvus_password}
+        settings["milvus_collection_name"] = {"value": self.milvus_collection_name}
+
+        # faiss
+        settings["faiss_path"] = {"value": self.faiss_path}
+
+        return settings
+
 
 view_model = ViewModel()
diff --git a/src/pai_rag/app/web/webui.py b/src/pai_rag/app/web/webui.py
new file mode 100644
index 00000000..c920968a
--- /dev/null
+++ b/src/pai_rag/app/web/webui.py
@@ -0,0 +1,48 @@
+import gradio as gr
+from pai_rag.app.web.view_model import view_model
+from pai_rag.app.web.tabs.settings_tab import create_setting_tab
+from pai_rag.app.web.tabs.upload_tab import create_upload_tab
+from pai_rag.app.web.tabs.chat_tab import create_chat_tab
+from pai_rag.app.web.element_manager import elem_manager
+from pai_rag.app.web.ui_constants import (
+    DEFAULT_CSS_STYLE,
+    WELCOME_MESSAGE,
+)
+
+import logging
+
+logger = logging.getLogger("WebUILogger")
+
+
+def resume_ui():
+    outputs = {}
+    component_settings = view_model.to_component_settings()
+
+    for elem in elem_manager.get_elem_list():
+        elem_id = elem.elem_id
+        elem_attr = component_settings[elem_id]
+        elem = elem_manager.get_elem_by_id(elem_id=elem_id)
+
+        # Gradio 3.41.0 requires reading .value here; newer Gradio versions can return the component directly.
+        outputs[elem] = elem.__class__(**elem_attr).value
+
+    return outputs
+
+
+def create_ui():
+    with gr.Blocks(css=DEFAULT_CSS_STYLE) as homepage:
+        # generate components
+        gr.Markdown(value=WELCOME_MESSAGE)
+        with gr.Tab("\N{rocket} Settings"):
+            setting_elements = create_setting_tab()
+            elem_manager.add_elems(setting_elements)
+        with gr.Tab("\N{whale} Upload"):
+            upload_elements = create_upload_tab()
+            elem_manager.add_elems(upload_elements)
+        with gr.Tab("\N{fire} Chat"):
+            chat_elements = create_chat_tab()
+            elem_manager.add_elems(chat_elements)
+        homepage.load(
+            resume_ui, outputs=elem_manager.get_elem_list(), concurrency_limit=None
+        )
+    return homepage
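
Illustrative sketch (not part of the patch): how the shared-state flow above fits together, assuming the modules land as in this diff. FakeTextbox is a hypothetical stand-in for a Gradio component; ElementManager only relies on an elem_id attribute, and resume_ui only re-instantiates the component class from the view model's per-element settings and reads back .value.

    from pai_rag.app.web.element_manager import ElementManager

    # Hypothetical minimal component: hashable, carries an elem_id, and can be
    # re-instantiated with keyword settings just like a Gradio component.
    class FakeTextbox:
        def __init__(self, elem_id=None, value=None):
            self.elem_id = elem_id
            self.value = value

    manager = ElementManager()
    manager.add_elems({"embed_dim": FakeTextbox(elem_id="embed_dim")})

    # What view_model.to_component_settings() would produce for this element.
    settings = {"embed_dim": {"value": 1024}}

    # The same round trip resume_ui() performs on homepage.load.
    for elem in manager.get_elem_list():
        restored = elem.__class__(**settings[elem.elem_id]).value
        print(elem.elem_id, "->", restored)  # embed_dim -> 1024

Because each tab returns its components keyed by elem_id and registers them in the shared elem_manager, a single homepage.load(resume_ui, ...) call can repopulate all three tabs from one view_model.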