Skip to content

Commit

Permalink
support llm infer param: temperature (#52)
Browse files Browse the repository at this point in the history
* support llm infer param: temperature

* modify style of retrieval chunks

* support safe_html_content
  • Loading branch information
wwxxzz authored Jun 6, 2024
1 parent 115a695 commit 4bb4ce8
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 237 deletions.
216 changes: 24 additions & 192 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ llama-index-llms-huggingface = "^0.2.0"
pytest-asyncio = "^0.23.7"
pytest-cov = "^5.0.0"
xlrd = "^2.0.1"
markdown = "^3.6"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
8 changes: 2 additions & 6 deletions src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,15 @@

class RagQuery(BaseModel):
question: str
topk: int | None = 3
topp: float | None = 0.8
temperature: float | None = 0.7
temperature: float | None = 0.1
vector_topk: int | None = 3
score_threshold: float | None = 0.5
chat_history: List[Dict[str, str]] | None = None


class LlmQuery(BaseModel):
question: str
topk: int | None = 3
topp: float | None = 0.8
temperature: float | None = 0.7
temperature: float | None = 0.1
chat_history: List[Dict[str, str]] | None = None


Expand Down
22 changes: 12 additions & 10 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any
import requests
import html
import markdown

cache_config = None

Expand Down Expand Up @@ -55,11 +57,9 @@ def query_llm(
self,
text: str,
session_id: str = None,
temperature: float = 0.7,
top_p: float = 0.8,
eas_llm_top_k: float = 30,
temperature: float = 0.1,
):
q = dict(question=text, topp=top_p, topk=eas_llm_top_k, temperature=temperature)
q = dict(question=text, temperature=temperature)

r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q)
r.raise_for_status()
Expand All @@ -76,12 +76,14 @@ def query_vector(self, text: str):
session_id = r.headers["x-session-id"]
response = dotdict(json.loads(r.text))
response.session_id = session_id
formatted_text = "\n\n".join(
[
f"""[Doc {i+1}] [score: {doc["score"]}]\n{doc["text"]}"""
for i, doc in enumerate(response["docs"])
]
)
formatted_text = "<tr><th>Document</th><th>Score</th><th>Text</th></tr>\n"
for i, doc in enumerate(response["docs"]):
html_content = markdown.markdown(doc["text"])
safe_html_content = html.escape(html_content).replace("\n", "<br>")
formatted_text += '<tr style="font-size: 13px;"><td>Doc {}</td><td>{}</td><td>{}</td></tr>\n'.format(
i + 1, doc["score"], safe_html_content
)
formatted_text = "<table>\n<tbody>\n" + formatted_text + "</tbody>\n</table>"
response["answer"] = formatted_text
return response

Expand Down
44 changes: 21 additions & 23 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def create_chat_tab() -> Dict[str, Any]:
)

with gr.Column(visible=True) as vs_col:
vec_model_argument = gr.Accordion("Parameters of Vector Retrieval")
vec_model_argument = gr.Accordion(
"Parameters of Vector Retrieval", open=False
)

with vec_model_argument:
similarity_top_k = gr.Slider(
Expand Down Expand Up @@ -101,38 +103,22 @@ def create_chat_tab() -> Dict[str, Any]:
retrieval_mode,
}
with gr.Column(visible=True) as llm_col:
model_argument = gr.Accordion("Inference Parameters of LLM")
model_argument = gr.Accordion("Inference Parameters of LLM", open=False)
with model_argument:
include_history = gr.Checkbox(
label="Chat history",
info="Query with chat history.",
elem_id="include_history",
)
llm_topk = gr.Slider(
minimum=0,
maximum=100,
step=1,
value=30,
elem_id="llm_topk",
label="Top K (choose between 0 and 100)",
)
llm_topp = gr.Slider(
minimum=0,
maximum=1,
step=0.01,
value=0.8,
elem_id="llm_topp",
label="Top P (choose between 0 and 1)",
)
llm_temp = gr.Slider(
minimum=0,
maximum=1,
step=0.01,
value=0.7,
elem_id="llm_temp",
step=0.001,
value=0.1,
elem_id="llm_temperature",
label="Temperature (choose between 0 and 1)",
)
llm_args = {llm_topk, llm_topp, llm_temp, include_history}
llm_args = {llm_temp, include_history}

with gr.Column(visible=True) as lc_col:
prm_type = gr.Radio(
Expand Down Expand Up @@ -198,26 +184,32 @@ def change_query_radio(query_type):
if query_type == "Retrieval":
return {
vs_col: gr.update(visible=True),
vec_model_argument: gr.update(open=True),
llm_col: gr.update(visible=False),
model_argument: gr.update(open=False),
lc_col: gr.update(visible=False),
}
elif query_type == "LLM":
return {
vs_col: gr.update(visible=False),
vec_model_argument: gr.update(open=False),
llm_col: gr.update(visible=True),
model_argument: gr.update(open=True),
lc_col: gr.update(visible=False),
}
elif query_type == "RAG (Retrieval + LLM)":
return {
vs_col: gr.update(visible=True),
vec_model_argument: gr.update(open=False),
llm_col: gr.update(visible=True),
model_argument: gr.update(open=False),
lc_col: gr.update(visible=True),
}

query_type.change(
fn=change_query_radio,
inputs=query_type,
outputs=[vs_col, llm_col, lc_col],
outputs=[vs_col, vec_model_argument, llm_col, model_argument, lc_col],
)

with gr.Column(scale=8):
Expand All @@ -239,6 +231,12 @@ def change_query_radio(query_type):
[question, chatbot, cur_tokens],
api_name="respond",
)
question.submit(
respond,
chat_args,
[question, chatbot, cur_tokens],
api_name="respond",
)
clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens])
return {
similarity_top_k.elem_id: similarity_top_k,
Expand Down
3 changes: 3 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ViewModel(BaseModel):
llm_eas_model_name: str = None
llm_api_key: str = None
llm_api_model_name: str = None
llm_temperature: float = 0.1

# chunking
parser_type: str = "Sentence"
Expand Down Expand Up @@ -115,6 +116,7 @@ def sync_app_config(self, config):
self.llm_eas_url = config["llm"].get("endpoint", self.llm_eas_url)
self.llm_eas_token = config["llm"].get("token", self.llm_eas_token)
self.llm_api_key = config["llm"].get("api_key", self.llm_api_key)
self.llm_temperature = config["llm"].get("temperature", self.llm_temperature)
if self.llm == "PaiEAS":
self.llm_eas_model_name = config["llm"].get("name", self.llm_eas_model_name)
else:
Expand Down Expand Up @@ -217,6 +219,7 @@ def to_app_config(self):
config["llm"]["endpoint"] = self.llm_eas_url
config["llm"]["token"] = self.llm_eas_token
config["llm"]["api_key"] = self.llm_api_key
config["llm"]["temperature"] = self.llm_temperature
if self.llm == "PaiEas":
config["llm"]["name"] = self.llm_eas_model_name
else:
Expand Down
19 changes: 13 additions & 6 deletions src/pai_rag/modules/llm/llm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
f"""
[Parameters][LLM:OpenAI]
model = {config.get("name", "gpt-3.5-turbo")},
temperature = {config.get("temperature", 0.5)},
temperature = {config.get("temperature", 0.1)},
system_prompt = {config.get("system_prompt", "Please answer in Chinese.")}
"""
)
llm = OpenAI(
model=config.get("name", "gpt-3.5-turbo"),
temperature=config.get("temperature", 0.5),
temperature=config.get("temperature", 0.1),
system_prompt=config.get("system_prompt", "Please answer in Chinese."),
api_key=config.get("api_key", None),
)
Expand All @@ -39,13 +39,13 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
f"""
[Parameters][LLM:AzureOpenAI]
model = {config.get("name", "gpt-35-turbo")},
temperature = {config.get("temperature", 0.5)},
temperature = {config.get("temperature", 0.1)},
system_prompt = {config.get("system_prompt", "Please answer in Chinese.")}
"""
)
llm = AzureOpenAI(
model=config.get("name", "gpt-35-turbo"),
temperature=config.get("temperature", 0.5),
temperature=config.get("temperature", 0.1),
system_prompt=config.get("system_prompt", "Please answer in Chinese."),
)
elif source == "dashscope":
Expand All @@ -56,7 +56,9 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
model = {model_name}
"""
)
llm = DashScope(model_name=model_name)
llm = DashScope(
model_name=model_name, temperature=config.get("temperature", 0.1)
)
elif source == "paieas":
model_name = config["name"]
endpoint = config["endpoint"]
Expand All @@ -69,7 +71,12 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
token = {token}
"""
)
llm = PaiEAS(endpoint=endpoint, token=token, model_name=model_name)
llm = PaiEAS(
endpoint=endpoint,
token=token,
model_name=model_name,
temperature=config.get("temperature", 0.1),
)
else:
raise ValueError(f"Unknown LLM source: '{config['llm']['source']}'")

Expand Down

0 comments on commit 4bb4ce8

Please sign in to comment.