Commit

Fix session-id (#55)
moria97 authored Jun 7, 2024
1 parent 2756bfb commit 48b64a9
Showing 7 changed files with 35 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/pai_rag/app/api/middleware.py
@@ -31,7 +31,7 @@ def init_middleware(app: FastAPI):
 def _configure_session_middleware(app):
     app.add_middleware(
         CorrelationIdMiddleware,
-        header_name="X-Session-ID",
+        header_name="X-Request-ID",
     )


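A minimal standalone sketch (not part of the commit; the imports and names are taken from the hunk above): after this change the correlation middleware reverts to the conventional X-Request-ID header for request tracing, and session tracking moves into the request and response bodies instead of a custom header.

# Sketch only: mirrors the middleware registration shown above.
from asgi_correlation_id import CorrelationIdMiddleware
from fastapi import FastAPI

app = FastAPI()
app.add_middleware(
    CorrelationIdMiddleware,
    header_name="X-Request-ID",  # per-request trace id; no longer doubles as the session id
)
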
4 changes: 4 additions & 0 deletions src/pai_rag/app/api/models.py
@@ -8,12 +8,14 @@ class RagQuery(BaseModel):
     vector_topk: int | None = 3
     score_threshold: float | None = 0.5
     chat_history: List[Dict[str, str]] | None = None
+    session_id: str | None = None


 class LlmQuery(BaseModel):
     question: str
     temperature: float | None = 0.1
     chat_history: List[Dict[str, str]] | None = None
+    session_id: str | None = None


 class RetrievalQuery(BaseModel):
@@ -24,12 +26,14 @@ class RetrievalQuery(BaseModel):

 class RagResponse(BaseModel):
     answer: str
+    session_id: str | None = None
     # TODO
     # context: List[str] | None = None


 class LlmResponse(BaseModel):
     answer: str
+    session_id: str | None = None


 class ContextDoc(BaseModel):
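
A short usage sketch (not part of the commit; it assumes the models above are importable from pai_rag.app.api.models, matching the file path): the optional session_id now travels in the JSON body on both the query and the response side.

# Sketch only: round-tripping the new optional field.
from pai_rag.app.api.models import RagQuery, RagResponse

first = RagQuery(question="What is PAI-RAG?")           # no session yet; defaults to None
reply = RagResponse(answer="...", session_id="abc123")  # the server fills session_id in the body
follow_up = RagQuery(question="Tell me more.", session_id=reply.session_id)
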
19 changes: 8 additions & 11 deletions src/pai_rag/app/web/rag_client.py
@@ -44,13 +44,10 @@ def load_data_url(self):
         return f"{self.endpoint}service/data"

     def query(self, text: str, session_id: str = None):
-        q = dict(question=text)
-        r = requests.post(self.query_url, headers={"X-Session-ID": session_id}, json=q)
+        q = dict(question=text, session_id=session_id)
+        r = requests.post(self.query_url, json=q)
         r.raise_for_status()
-        session_id = r.headers["x-session-id"]
         response = dotdict(json.loads(r.text))
-        response.session_id = session_id
-
         return response

     def query_llm(
@@ -59,23 +56,23 @@ def query_llm(
         session_id: str = None,
         temperature: float = 0.1,
     ):
-        q = dict(question=text, temperature=temperature)
+        q = dict(
+            question=text,
+            temperature=temperature,
+            session_id=session_id,
+        )

-        r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q)
+        r = requests.post(self.llm_url, json=q)
         r.raise_for_status()
-        session_id = r.headers["x-session-id"]
         response = dotdict(json.loads(r.text))
-        response.session_id = session_id

         return response

     def query_vector(self, text: str):
         q = dict(question=text)
         r = requests.post(self.retrieval_url, json=q)
         r.raise_for_status()
-        session_id = r.headers["x-session-id"]
         response = dotdict(json.loads(r.text))
-        response.session_id = session_id
         formatted_text = "<tr><th>Document</th><th>Score</th><th>Text</th></tr>\n"
         for i, doc in enumerate(response["docs"]):
             html_content = markdown.markdown(doc["text"])
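
Illustrative client usage (not part of the commit; it assumes a client instance named rag_client as used by the web tabs): the first call may omit session_id, and the id returned in the response body is passed back on follow-up calls.

# Sketch only: the session id now rides in the JSON payload instead of the X-Session-ID header.
first = rag_client.query("What is PAI-RAG?")
follow_up = rag_client.query("Tell me more.", session_id=first.session_id)
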
8 changes: 4 additions & 4 deletions src/pai_rag/app/web/tabs/chat_tab.py
@@ -39,20 +39,20 @@ def respond(input_elements: List[Any]):
     msg = update_dict["question"]
     chatbot = update_dict["chatbot"]

+    if not update_dict["include_history"]:
+        current_session_id = None
+
     if query_type == "LLM":
         response = rag_client.query_llm(
             msg,
             session_id=current_session_id,
         )

+        current_session_id = response.session_id
     elif query_type == "Retrieval":
         response = rag_client.query_vector(msg)
     else:
         response = rag_client.query(msg, session_id=current_session_id)
-    if update_dict["include_history"]:
-        current_session_id = response.session_id
-    else:
-        current_session_id = None
     chatbot.append((msg, response.answer))
     return "", chatbot, 0

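A condensed, hypothetical driver (not the code in this commit) showing the intent of the new flow: turning history off clears the session id before the request, and the id returned by the service seeds the next turn.

# Sketch only: simplified version of how the tab threads the session id between turns.
current_session_id = None

def ask(question: str, include_history: bool) -> str:
    global current_session_id
    if not include_history:
        current_session_id = None
    response = rag_client.query(question, session_id=current_session_id)
    current_session_id = response.session_id
    return response.answer
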
4 changes: 1 addition & 3 deletions src/pai_rag/app/web/tabs/vector_db_panel.py
@@ -11,9 +11,7 @@ def create_vector_db_panel(
     components = []
     with gr.Column():
         with gr.Column():
-            _ = gr.Markdown(
-                value=f"**Please check your Vector Store for {view_model.vectordb_type}.**"
-            )
+            _ = gr.Markdown(value="**Please check your Vector Store.**")
             vectordb_type = gr.Dropdown(
                 ["Hologres", "Milvus", "ElasticSearch", "AnalyticDB", "FAISS"],
                 label="Which VectorStore do you want to use?",
31 changes: 17 additions & 14 deletions src/pai_rag/core/rag_application.py
@@ -1,4 +1,3 @@
-from asgi_correlation_id import correlation_id
 from pai_rag.data.rag_dataloader import RagDataLoader
 from pai_rag.utils.oss_cache import OssCache
 from pai_rag.modules.module_registry import module_registry
@@ -15,8 +14,11 @@
 from llama_index.core.schema import QueryBundle

 import logging
+from uuid import uuid4

-DEFAULT_SESSION_ID = "default"  # For test-only

+def uuid_generator() -> str:
+    return uuid4().hex
+

 class RagApplication:
@@ -63,8 +65,6 @@ async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
         if not query.question:
             return RetrievalResponse(docs=[])

-        session_id = correlation_id.get() or DEFAULT_SESSION_ID
-        self.logger.info(f"Get session ID: {session_id}.")
         query_bundle = QueryBundle(query.question)
         node_results = await self.query_engine.aretrieve(query_bundle)

@@ -89,17 +89,19 @@ async def aquery(self, query: RagQuery) -> RagResponse:
         Returns:
             RagResponse
         """
+        session_id = query.session_id or uuid_generator()
+        self.logger.info(f"Get session ID: {session_id}.")
         if not query.question:
-            return RagResponse(answer="Empty query. Please input your question.")
+            return RagResponse(
+                answer="Empty query. Please input your question.", session_id=session_id
+            )

-        session_id = correlation_id.get() or DEFAULT_SESSION_ID
-        self.logger.info(f"Get session ID: {session_id}.")
         query_chat_engine = self.chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
         response = await query_chat_engine.achat(query.question)
         self.chat_store.persist()
-        return RagResponse(answer=response.response)
+        return RagResponse(answer=response.response, session_id=session_id)

     async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
         """Query answer from LLM response asynchronously.
@@ -112,17 +114,20 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
         Returns:
             LlmResponse
         """
+        session_id = query.session_id or uuid_generator()
+        self.logger.info(f"Get session ID: {session_id}.")
+
         if not query.question:
-            return LlmResponse(answer="Empty query. Please input your question.")
+            return LlmResponse(
+                answer="Empty query. Please input your question.", session_id=session_id
+            )

-        session_id = correlation_id.get() or DEFAULT_SESSION_ID
-        self.logger.info(f"Get session ID: {session_id}.")
         llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
         response = await llm_chat_engine.achat(query.question)
         self.chat_store.persist()
-        return LlmResponse(answer=response.response)
+        return LlmResponse(answer=response.response, session_id=session_id)

     async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
         """Query answer from RAG App via web search asynchronously.
@@ -138,8 +143,6 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
         if not query.question:
             return LlmResponse(answer="Empty query. Please input your question.")

-        session_id = correlation_id.get()
-        self.logger.info(f"Get session ID: {session_id}.")
         response = await self.agent.achat(query.question)
         return LlmResponse(answer=response.response)

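The essence of the new session handling as a standalone sketch (uuid_generator comes from the diff; resolve_session_id is a hypothetical helper added only for illustration): a caller-supplied id wins, otherwise a fresh hex UUID is minted and echoed back in the response.

# Sketch only: fallback used by aquery and aquery_llm above.
from uuid import uuid4


def uuid_generator() -> str:
    return uuid4().hex


def resolve_session_id(requested: str | None) -> str:
    # Reuse the client's session when given; otherwise start a new one.
    return requested or uuid_generator()
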
6 changes: 0 additions & 6 deletions src/pai_rag/core/rag_service.py
@@ -47,27 +47,21 @@ def reload(self, new_config: Any):
         self.rag.reload(self.rag_configuration.get_value())
         self.rag_configuration.persist()

-    @trace_correlation_id
     async def add_knowledge(self, file_dir: str, enable_qa_extraction: bool = False):
         await self.rag.load_knowledge(file_dir, enable_qa_extraction)

-    @trace_correlation_id
     async def aquery(self, query: RagQuery) -> RagResponse:
         return await self.rag.aquery(query)

-    @trace_correlation_id
     async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
         return await self.rag.aquery_llm(query)

-    @trace_correlation_id
     async def aquery_retrieval(self, query: RetrievalQuery):
         return await self.rag.aquery_retrieval(query)

-    @trace_correlation_id
     async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
         return await self.rag.aquery_agent(query)

-    @trace_correlation_id
     async def batch_evaluate_retrieval_and_response(self, type):
         return await self.rag.batch_evaluate_retrieval_and_response(type)

