diff --git a/src/pai_rag/app/api/middleware.py b/src/pai_rag/app/api/middleware.py index 06d42009..13a29a85 100644 --- a/src/pai_rag/app/api/middleware.py +++ b/src/pai_rag/app/api/middleware.py @@ -31,7 +31,7 @@ def init_middleware(app: FastAPI): def _configure_session_middleware(app): app.add_middleware( CorrelationIdMiddleware, - header_name="X-Session-ID", + header_name="X-Request-ID", ) diff --git a/src/pai_rag/app/api/models.py b/src/pai_rag/app/api/models.py index 6cbe25c6..8cfede1e 100644 --- a/src/pai_rag/app/api/models.py +++ b/src/pai_rag/app/api/models.py @@ -8,12 +8,14 @@ class RagQuery(BaseModel): vector_topk: int | None = 3 score_threshold: float | None = 0.5 chat_history: List[Dict[str, str]] | None = None + session_id: str | None = None class LlmQuery(BaseModel): question: str temperature: float | None = 0.1 chat_history: List[Dict[str, str]] | None = None + session_id: str | None = None class RetrievalQuery(BaseModel): @@ -24,12 +26,14 @@ class RetrievalQuery(BaseModel): class RagResponse(BaseModel): answer: str + session_id: str | None = None # TODO # context: List[str] | None = None class LlmResponse(BaseModel): answer: str + session_id: str | None = None class ContextDoc(BaseModel): diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 7d034377..35e88508 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -44,13 +44,11 @@ def load_data_url(self): return f"{self.endpoint}service/data" def query(self, text: str, session_id: str = None): - q = dict(question=text) - r = requests.post(self.query_url, headers={"X-Session-ID": session_id}, json=q) + q = dict(question=text, session_id=session_id) + r = requests.post(self.query_url, json=q) r.raise_for_status() - session_id = r.headers["x-session-id"] response = dotdict(json.loads(r.text)) - response.session_id = session_id return response def query_llm( @@ -59,13 +56,15 @@ def query_llm( session_id: str = None, temperature:
float = 0.1, ): - q = dict(question=text, temperature=temperature) + q = dict( + question=text, + temperature=temperature, + session_id=session_id, + ) - r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q) + r = requests.post(self.llm_url, json=q) r.raise_for_status() - session_id = r.headers["x-session-id"] response = dotdict(json.loads(r.text)) - response.session_id = session_id return response @@ -73,9 +72,7 @@ def query_vector(self, text: str): q = dict(question=text) r = requests.post(self.retrieval_url, json=q) r.raise_for_status() - session_id = r.headers["x-session-id"] response = dotdict(json.loads(r.text)) - response.session_id = session_id formatted_text = "DocumentScoreText\n" for i, doc in enumerate(response["docs"]): html_content = markdown.markdown(doc["text"]) diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py index c6eee69d..8bb7998b 100644 --- a/src/pai_rag/app/web/tabs/chat_tab.py +++ b/src/pai_rag/app/web/tabs/chat_tab.py @@ -39,20 +39,20 @@ def respond(input_elements: List[Any]): msg = update_dict["question"] chatbot = update_dict["chatbot"] + if not update_dict["include_history"]: + current_session_id = None + if query_type == "LLM": response = rag_client.query_llm( msg, session_id=current_session_id, ) - + current_session_id = response.session_id elif query_type == "Retrieval": response = rag_client.query_vector(msg) else: response = rag_client.query(msg, session_id=current_session_id) - if update_dict["include_history"]: current_session_id = response.session_id - else: - current_session_id = None chatbot.append((msg, response.answer)) return "", chatbot, 0 diff --git a/src/pai_rag/app/web/tabs/vector_db_panel.py b/src/pai_rag/app/web/tabs/vector_db_panel.py index fc66df8a..ea3e175b 100644 --- a/src/pai_rag/app/web/tabs/vector_db_panel.py +++ b/src/pai_rag/app/web/tabs/vector_db_panel.py @@ -11,9 +11,7 @@ def create_vector_db_panel( components = [] with gr.Column(): with 
gr.Column(): - _ = gr.Markdown( - value=f"**Please check your Vector Store for {view_model.vectordb_type}.**" - ) + _ = gr.Markdown(value="**Please check your Vector Store.**") vectordb_type = gr.Dropdown( ["Hologres", "Milvus", "ElasticSearch", "AnalyticDB", "FAISS"], label="Which VectorStore do you want to use?", diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index 17f4d6e4..8a842ca6 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -1,4 +1,3 @@ -from asgi_correlation_id import correlation_id from pai_rag.data.rag_dataloader import RagDataLoader from pai_rag.utils.oss_cache import OssCache from pai_rag.modules.module_registry import module_registry @@ -15,8 +14,11 @@ from llama_index.core.schema import QueryBundle import logging +from uuid import uuid4 -DEFAULT_SESSION_ID = "default" # For test-only + +def uuid_generator() -> str: + return uuid4().hex class RagApplication: @@ -63,8 +65,6 @@ async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse: if not query.question: return RetrievalResponse(docs=[]) - session_id = correlation_id.get() or DEFAULT_SESSION_ID - self.logger.info(f"Get session ID: {session_id}.") query_bundle = QueryBundle(query.question) node_results = await self.query_engine.aretrieve(query_bundle) @@ -89,17 +89,19 @@ async def aquery(self, query: RagQuery) -> RagResponse: Returns: RagResponse """ + session_id = query.session_id or uuid_generator() + self.logger.info(f"Get session ID: {session_id}.") if not query.question: - return RagResponse(answer="Empty query. Please input your question.") + return RagResponse( + answer="Empty query. 
Please input your question.", session_id=session_id + ) - session_id = correlation_id.get() or DEFAULT_SESSION_ID - self.logger.info(f"Get session ID: {session_id}.") query_chat_engine = self.chat_engine_factory.get_chat_engine( session_id, query.chat_history ) response = await query_chat_engine.achat(query.question) self.chat_store.persist() - return RagResponse(answer=response.response) + return RagResponse(answer=response.response, session_id=session_id) async def aquery_llm(self, query: LlmQuery) -> LlmResponse: """Query answer from LLM response asynchronously. @@ -112,17 +114,20 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse: Returns: LlmResponse """ + session_id = query.session_id or uuid_generator() + self.logger.info(f"Get session ID: {session_id}.") + if not query.question: - return LlmResponse(answer="Empty query. Please input your question.") + return LlmResponse( + answer="Empty query. Please input your question.", session_id=session_id + ) - session_id = correlation_id.get() or DEFAULT_SESSION_ID - self.logger.info(f"Get session ID: {session_id}.") llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine( session_id, query.chat_history ) response = await llm_chat_engine.achat(query.question) self.chat_store.persist() - return LlmResponse(answer=response.response) + return LlmResponse(answer=response.response, session_id=session_id) async def aquery_agent(self, query: LlmQuery) -> LlmResponse: """Query answer from RAG App via web search asynchronously. @@ -138,8 +143,6 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse: if not query.question: return LlmResponse(answer="Empty query. 
Please input your question.") - session_id = correlation_id.get() - self.logger.info(f"Get session ID: {session_id}.") response = await self.agent.achat(query.question) return LlmResponse(answer=response.response) diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py index ada3d29a..a7f0fc7d 100644 --- a/src/pai_rag/core/rag_service.py +++ b/src/pai_rag/core/rag_service.py @@ -47,27 +47,21 @@ def reload(self, new_config: Any): self.rag.reload(self.rag_configuration.get_value()) self.rag_configuration.persist() - @trace_correlation_id async def add_knowledge(self, file_dir: str, enable_qa_extraction: bool = False): await self.rag.load_knowledge(file_dir, enable_qa_extraction) - @trace_correlation_id async def aquery(self, query: RagQuery) -> RagResponse: return await self.rag.aquery(query) - @trace_correlation_id async def aquery_llm(self, query: LlmQuery) -> LlmResponse: return await self.rag.aquery_llm(query) - @trace_correlation_id async def aquery_retrieval(self, query: RetrievalQuery): return await self.rag.aquery_retrieval(query) - @trace_correlation_id async def aquery_agent(self, query: LlmQuery) -> LlmResponse: return await self.rag.aquery_agent(query) - @trace_correlation_id async def batch_evaluate_retrieval_and_response(self, type): return await self.rag.batch_evaluate_retrieval_and_response(type)