Skip to content

Commit

Permalink
feat: enable Gradio page to share state between tabs (#48)
Browse files Browse the repository at this point in the history
* Remove entrypoint

* Update ui

* Revert dockerfile change

* Delete debug clauses

* Fix embed dim value

* Update view model
  • Loading branch information
moria97 authored Jun 4, 2024
1 parent 7d2aecc commit 1705a9f
Show file tree
Hide file tree
Showing 12 changed files with 803 additions and 615 deletions.
3 changes: 1 addition & 2 deletions src/pai_rag/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api import query
from pai_rag.app.api.middleware import init_middleware
from pai_rag.app.web.ui import create_ui
from pai_rag.app.web.webui import create_ui
from pai_rag.app.web.rag_client import rag_client

# UI_PATH = "/ui"
UI_PATH = ""


Expand Down
48 changes: 48 additions & 0 deletions src/pai_rag/app/web/element_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING, Dict, Generator, List, Tuple

if TYPE_CHECKING:
from gradio.components import Component


class ElementManager:
def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}

def add_elems(self, elem_dict: Dict[str, "Component"]) -> None:
r"""
Adds elements to manager.
"""
for elem_id, elem in elem_dict.items():
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id

def get_elem_list(self) -> List["Component"]:
r"""
Returns the list of all elements.
"""
return list(self._id_to_elem.values())

def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
r"""
Returns an iterator over all elements with their names.
"""
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem

def get_elem_by_id(self, elem_id: str) -> "Component":
r"""
Gets element by id.
Example: top.lang, train.dataset
"""
return self._id_to_elem[elem_id]

def get_id_by_elem(self, elem: "Component") -> str:
r"""
Gets id by element.
"""
return self._elem_to_id[elem]


elem_manager = ElementManager()
11 changes: 0 additions & 11 deletions src/pai_rag/app/web/prompts.py

This file was deleted.

249 changes: 249 additions & 0 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
from typing import Dict, Any, List
import gradio as gr
from pai_rag.app.web.rag_client import rag_client
from pai_rag.app.web.view_model import view_model
from pai_rag.app.web.ui_constants import (
SIMPLE_PROMPTS,
GENERAL_PROMPTS,
EXTRACT_URL_PROMPTS,
ACCURATE_CONTENT_PROMPTS,
)


current_session_id = None


def clear_history(chatbot):
chatbot = []
global current_session_id
current_session_id = None
return chatbot, 0


def respond(input_elements: List[Any]):
global current_session_id

update_dict = {}
for element, value in input_elements.items():
update_dict[element.elem_id] = value

# empty input.
if not update_dict["question"]:
return "", update_dict["chatbot"], 0

view_model.update(update_dict)
new_config = view_model.to_app_config()
rag_client.reload_config(new_config)

query_type = update_dict["query_type"]
msg = update_dict["question"]
chatbot = update_dict["chatbot"]

if query_type == "LLM":
response = rag_client.query_llm(
msg,
session_id=current_session_id,
)

elif query_type == "Retrieval":
response = rag_client.query_vector(msg)
else:
response = rag_client.query(msg, session_id=current_session_id)
if update_dict["include_history"]:
current_session_id = response.session_id
else:
current_session_id = None
chatbot.append((msg, response.answer))
return "", chatbot, 0


def create_chat_tab() -> Dict[str, Any]:
with gr.Row():
with gr.Column(scale=2):
query_type = gr.Radio(
["Retrieval", "LLM", "RAG (Retrieval + LLM)"],
label="\N{fire} Which query do you want to use?",
elem_id="query_type",
value="RAG (Retrieval + LLM)",
)

with gr.Column(visible=True) as vs_col:
vec_model_argument = gr.Accordion("Parameters of Vector Retrieval")

with vec_model_argument:
similarity_top_k = gr.Slider(
minimum=0,
maximum=100,
step=1,
elem_id="similarity_top_k",
label="Top K (choose between 0 and 100)",
)
# similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)")
rerank_model = gr.Radio(
[
"no-reranker",
"bge-reranker-base",
"bge-reranker-large",
"llm-reranker",
],
label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)",
elem_id="rerank_model",
)
retrieval_mode = gr.Radio(
["Embedding Only", "Keyword Ensembled", "Keyword Only"],
label="Retrieval Mode",
elem_id="retrieval_mode",
)
vec_args = {
similarity_top_k,
# similarity_cutoff,
rerank_model,
retrieval_mode,
}
with gr.Column(visible=True) as llm_col:
model_argument = gr.Accordion("Inference Parameters of LLM")
with model_argument:
include_history = gr.Checkbox(
label="Chat history",
info="Query with chat history.",
elem_id="include_history",
)
llm_topk = gr.Slider(
minimum=0,
maximum=100,
step=1,
value=30,
elem_id="llm_topk",
label="Top K (choose between 0 and 100)",
)
llm_topp = gr.Slider(
minimum=0,
maximum=1,
step=0.01,
value=0.8,
elem_id="llm_topp",
label="Top P (choose between 0 and 1)",
)
llm_temp = gr.Slider(
minimum=0,
maximum=1,
step=0.01,
value=0.7,
elem_id="llm_temp",
label="Temperature (choose between 0 and 1)",
)
llm_args = {llm_topk, llm_topp, llm_temp, include_history}

with gr.Column(visible=True) as lc_col:
prm_type = gr.Radio(
[
"Simple",
"General",
"Extract URL",
"Accurate Content",
"Custom",
],
label="\N{rocket} Please choose the prompt template type",
elem_id="prm_type",
)
text_qa_template = gr.Textbox(
label="prompt template",
placeholder="",
elem_id="text_qa_template",
lines=4,
)

def change_prompt_template(prm_type):
if prm_type == "Simple":
return {
text_qa_template: gr.update(
value=SIMPLE_PROMPTS, interactive=False
)
}
elif prm_type == "General":
return {
text_qa_template: gr.update(
value=GENERAL_PROMPTS, interactive=False
)
}
elif prm_type == "Extract URL":
return {
text_qa_template: gr.update(
value=EXTRACT_URL_PROMPTS, interactive=False
)
}
elif prm_type == "Accurate Content":
return {
text_qa_template: gr.update(
value=ACCURATE_CONTENT_PROMPTS,
interactive=False,
)
}
else:
return {text_qa_template: gr.update(value="", interactive=True)}

prm_type.change(
fn=change_prompt_template,
inputs=prm_type,
outputs=[text_qa_template],
)

cur_tokens = gr.Textbox(
label="\N{fire} Current total count of tokens", visible=False
)

def change_query_radio(query_type):
global current_session_id
current_session_id = None
if query_type == "Retrieval":
return {
vs_col: gr.update(visible=True),
llm_col: gr.update(visible=False),
lc_col: gr.update(visible=False),
}
elif query_type == "LLM":
return {
vs_col: gr.update(visible=False),
llm_col: gr.update(visible=True),
lc_col: gr.update(visible=False),
}
elif query_type == "RAG (Retrieval + LLM)":
return {
vs_col: gr.update(visible=True),
llm_col: gr.update(visible=True),
lc_col: gr.update(visible=True),
}

query_type.change(
fn=change_query_radio,
inputs=query_type,
outputs=[vs_col, llm_col, lc_col],
)

with gr.Column(scale=8):
chatbot = gr.Chatbot(height=500, elem_id="chatbot")
question = gr.Textbox(label="Enter your question.", elem_id="question")
with gr.Row():
submitBtn = gr.Button("Submit", variant="primary")
clearBtn = gr.Button("Clear History", variant="secondary")

chat_args = (
{text_qa_template, question, query_type, chatbot}
.union(vec_args)
.union(llm_args)
)

submitBtn.click(
respond,
chat_args,
[question, chatbot, cur_tokens],
api_name="respond",
)
clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens])
return {
similarity_top_k.elem_id: similarity_top_k,
rerank_model.elem_id: rerank_model,
retrieval_mode.elem_id: retrieval_mode,
prm_type.elem_id: prm_type,
text_qa_template.elem_id: text_qa_template,
}
Loading

0 comments on commit 1705a9f

Please sign in to comment.