diff --git a/src/pai_rag/app/api/middleware.py b/src/pai_rag/app/api/middleware.py
index 06d42009..13a29a85 100644
--- a/src/pai_rag/app/api/middleware.py
+++ b/src/pai_rag/app/api/middleware.py
@@ -31,7 +31,7 @@ def init_middleware(app: FastAPI):
def _configure_session_middleware(app):
app.add_middleware(
CorrelationIdMiddleware,
- header_name="X-Session-ID",
+ header_name="X-Request-ID",
)
diff --git a/src/pai_rag/app/api/models.py b/src/pai_rag/app/api/models.py
index 6cbe25c6..8cfede1e 100644
--- a/src/pai_rag/app/api/models.py
+++ b/src/pai_rag/app/api/models.py
@@ -8,12 +8,14 @@ class RagQuery(BaseModel):
vector_topk: int | None = 3
score_threshold: float | None = 0.5
chat_history: List[Dict[str, str]] | None = None
+ session_id: str | None = None
class LlmQuery(BaseModel):
question: str
temperature: float | None = 0.1
chat_history: List[Dict[str, str]] | None = None
+ session_id: str | None = None
class RetrievalQuery(BaseModel):
@@ -24,12 +26,14 @@ class RetrievalQuery(BaseModel):
class RagResponse(BaseModel):
answer: str
+ session_id: str | None = None
# TODO
# context: List[str] | None = None
class LlmResponse(BaseModel):
answer: str
+ session_id: str | None = None
class ContextDoc(BaseModel):
diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 7d034377..35e88508 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -44,13 +44,10 @@ def load_data_url(self):
return f"{self.endpoint}service/data"
def query(self, text: str, session_id: str = None):
- q = dict(question=text)
- r = requests.post(self.query_url, headers={"X-Session-ID": session_id}, json=q)
+ q = dict(question=text, session_id=session_id)
+ r = requests.post(self.query_url, json=q)
r.raise_for_status()
- session_id = r.headers["x-session-id"]
response = dotdict(json.loads(r.text))
- response.session_id = session_id
-
return response
def query_llm(
@@ -59,13 +56,15 @@ def query_llm(
session_id: str = None,
temperature: float = 0.1,
):
- q = dict(question=text, temperature=temperature)
+ q = dict(
+ question=text,
+ temperature=temperature,
+ session_id=session_id,
+ )
- r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q)
+ r = requests.post(self.llm_url, json=q)
r.raise_for_status()
- session_id = r.headers["x-session-id"]
response = dotdict(json.loads(r.text))
- response.session_id = session_id
return response
@@ -73,9 +72,7 @@ def query_vector(self, text: str):
q = dict(question=text)
r = requests.post(self.retrieval_url, json=q)
r.raise_for_status()
- session_id = r.headers["x-session-id"]
response = dotdict(json.loads(r.text))
- response.session_id = session_id
        formatted_text = "Document | Score | Text |\n"
for i, doc in enumerate(response["docs"]):
html_content = markdown.markdown(doc["text"])
diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py
index c6eee69d..8bb7998b 100644
--- a/src/pai_rag/app/web/tabs/chat_tab.py
+++ b/src/pai_rag/app/web/tabs/chat_tab.py
@@ -39,20 +39,20 @@ def respond(input_elements: List[Any]):
msg = update_dict["question"]
chatbot = update_dict["chatbot"]
+ if not update_dict["include_history"]:
+ current_session_id = None
+
if query_type == "LLM":
response = rag_client.query_llm(
msg,
session_id=current_session_id,
)
-
+ current_session_id = response.session_id
elif query_type == "Retrieval":
response = rag_client.query_vector(msg)
else:
response = rag_client.query(msg, session_id=current_session_id)
- if update_dict["include_history"]:
current_session_id = response.session_id
- else:
- current_session_id = None
chatbot.append((msg, response.answer))
return "", chatbot, 0
diff --git a/src/pai_rag/app/web/tabs/vector_db_panel.py b/src/pai_rag/app/web/tabs/vector_db_panel.py
index fc66df8a..ea3e175b 100644
--- a/src/pai_rag/app/web/tabs/vector_db_panel.py
+++ b/src/pai_rag/app/web/tabs/vector_db_panel.py
@@ -11,9 +11,7 @@ def create_vector_db_panel(
components = []
with gr.Column():
with gr.Column():
- _ = gr.Markdown(
- value=f"**Please check your Vector Store for {view_model.vectordb_type}.**"
- )
+ _ = gr.Markdown(value="**Please check your Vector Store.**")
vectordb_type = gr.Dropdown(
["Hologres", "Milvus", "ElasticSearch", "AnalyticDB", "FAISS"],
label="Which VectorStore do you want to use?",
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 17f4d6e4..8a842ca6 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -1,4 +1,3 @@
-from asgi_correlation_id import correlation_id
from pai_rag.data.rag_dataloader import RagDataLoader
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.module_registry import module_registry
@@ -15,8 +14,11 @@
from llama_index.core.schema import QueryBundle
import logging
+from uuid import uuid4
-DEFAULT_SESSION_ID = "default" # For test-only
+
+def uuid_generator() -> str:
+ return uuid4().hex
class RagApplication:
@@ -63,8 +65,6 @@ async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RetrievalResponse(docs=[])
- session_id = correlation_id.get() or DEFAULT_SESSION_ID
- self.logger.info(f"Get session ID: {session_id}.")
query_bundle = QueryBundle(query.question)
node_results = await self.query_engine.aretrieve(query_bundle)
@@ -89,17 +89,19 @@ async def aquery(self, query: RagQuery) -> RagResponse:
Returns:
RagResponse
"""
+ session_id = query.session_id or uuid_generator()
+ self.logger.info(f"Get session ID: {session_id}.")
if not query.question:
- return RagResponse(answer="Empty query. Please input your question.")
+ return RagResponse(
+ answer="Empty query. Please input your question.", session_id=session_id
+ )
- session_id = correlation_id.get() or DEFAULT_SESSION_ID
- self.logger.info(f"Get session ID: {session_id}.")
query_chat_engine = self.chat_engine_factory.get_chat_engine(
session_id, query.chat_history
)
response = await query_chat_engine.achat(query.question)
self.chat_store.persist()
- return RagResponse(answer=response.response)
+ return RagResponse(answer=response.response, session_id=session_id)
async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
"""Query answer from LLM response asynchronously.
@@ -112,17 +114,20 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
Returns:
LlmResponse
"""
+ session_id = query.session_id or uuid_generator()
+ self.logger.info(f"Get session ID: {session_id}.")
+
if not query.question:
- return LlmResponse(answer="Empty query. Please input your question.")
+ return LlmResponse(
+ answer="Empty query. Please input your question.", session_id=session_id
+ )
- session_id = correlation_id.get() or DEFAULT_SESSION_ID
- self.logger.info(f"Get session ID: {session_id}.")
llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine(
session_id, query.chat_history
)
response = await llm_chat_engine.achat(query.question)
self.chat_store.persist()
- return LlmResponse(answer=response.response)
+ return LlmResponse(answer=response.response, session_id=session_id)
async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
"""Query answer from RAG App via web search asynchronously.
@@ -138,8 +143,6 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
if not query.question:
return LlmResponse(answer="Empty query. Please input your question.")
- session_id = correlation_id.get()
- self.logger.info(f"Get session ID: {session_id}.")
response = await self.agent.achat(query.question)
return LlmResponse(answer=response.response)
diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py
index ada3d29a..a7f0fc7d 100644
--- a/src/pai_rag/core/rag_service.py
+++ b/src/pai_rag/core/rag_service.py
@@ -47,27 +47,21 @@ def reload(self, new_config: Any):
self.rag.reload(self.rag_configuration.get_value())
self.rag_configuration.persist()
- @trace_correlation_id
async def add_knowledge(self, file_dir: str, enable_qa_extraction: bool = False):
await self.rag.load_knowledge(file_dir, enable_qa_extraction)
- @trace_correlation_id
async def aquery(self, query: RagQuery) -> RagResponse:
return await self.rag.aquery(query)
- @trace_correlation_id
async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
return await self.rag.aquery_llm(query)
- @trace_correlation_id
async def aquery_retrieval(self, query: RetrievalQuery):
return await self.rag.aquery_retrieval(query)
- @trace_correlation_id
async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
return await self.rag.aquery_agent(query)
- @trace_correlation_id
async def batch_evaluate_retrieval_and_response(self, type):
return await self.rag.batch_evaluate_retrieval_and_response(type)