diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py
index 12145faa..a5fd7485 100644
--- a/src/pai_rag/app/api/query.py
+++ b/src/pai_rag/app/api/query.py
@@ -5,18 +5,11 @@
     RagQuery,
     LlmQuery,
     RetrievalQuery,
-    RagResponse,
     LlmResponse,
     DataInput,
 )
 from fastapi.responses import StreamingResponse
-from typing import Generator, AsyncGenerator
-from fastapi import APIRouter, Security, Response
-from fastapi.responses import StreamingResponse
-from llama_index.core.base.llms.types import (
-    ChatResponse,
-)
-import json
+
 
 
 router = APIRouter()
@@ -26,11 +19,13 @@ async def aquery(query: RagQuery):
     if not query.stream:
         return response
     else:
+
         async def event_generator():
             full_response = ""
             async for token in response.async_response_gen():
                 full_response = full_response + token
                 yield token
+
         return StreamingResponse(
             event_generator(),
             media_type="text/event-stream",
@@ -43,16 +38,19 @@ async def aquery_llm(query: LlmQuery):
     if not query.stream:
         return response
     else:
+
         async def event_generator():
             full_response = ""
             async for token in response.async_response_gen():
                 full_response = full_response + token
                 yield token
+
         return StreamingResponse(
             event_generator(),
             media_type="text/event-stream",
         )
 
+
 @router.post("/query/retrieval")
 async def aquery_retrieval(query: RetrievalQuery):
     return await rag_service.aquery_retrieval(query)
diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index b1b3b52b..6018274c 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -41,7 +41,12 @@ def config_url(self):
     def load_data_url(self):
         return f"{self.endpoint}service/data"
 
-    def query(self, text: str, session_id: str = None, stream: bool = False,):
+    def query(
+        self,
+        text: str,
+        session_id: str = None,
+        stream: bool = False,
+    ):
         q = dict(question=text, stream=stream)
         r = requests.post(self.query_url, headers={"X-Session-ID": session_id}, json=q)
         r.raise_for_status()
@@ -61,7 +66,13 @@ def query_llm(
         eas_llm_top_k: float = 30,
         stream: bool = False,
     ):
-        q = dict(question=text, topp=top_p, topk=eas_llm_top_k, temperature=temperature, stream=stream)
+        q = dict(
+            question=text,
+            topp=top_p,
+            topk=eas_llm_top_k,
+            temperature=temperature,
+            stream=stream,
+        )
         r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q)
         r.raise_for_status()
 
diff --git a/src/pai_rag/app/web/ui.py b/src/pai_rag/app/web/ui.py
index 36a5b37e..8d8bb7ea 100644
--- a/src/pai_rag/app/web/ui.py
+++ b/src/pai_rag/app/web/ui.py
@@ -126,17 +126,17 @@ def respond(input_elements: List[Any]):
     msg = update_dict["question"]
     chatbot = update_dict["chatbot"]
     is_streaming = update_dict["is_streaming"]
-
+
     if query_type == "LLM":
         response = rag_client.query_llm(
-            text=msg,
-            session_id=current_session_id,
-            stream=is_streaming
+            text=msg, session_id=current_session_id, stream=is_streaming
         )
     elif query_type == "Retrieval":
         response = rag_client.query_vector(msg)
     else:
-        response = rag_client.query(text=msg, session_id=current_session_id,stream=is_streaming)
+        response = rag_client.query(
+            text=msg, session_id=current_session_id, stream=is_streaming
+        )
     if update_dict["include_history"]:
         current_session_id = response.session_id
     else:
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 96efcb94..67e45314 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -7,7 +7,6 @@
     RagQuery,
     LlmQuery,
     RetrievalQuery,
-    RagResponse,
     LlmResponse,
     ContextDoc,
     RetrievalResponse,
diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py
index 8571916a..7d4ad786 100644
--- a/src/pai_rag/core/rag_service.py
+++ b/src/pai_rag/core/rag_service.py
@@ -6,7 +6,6 @@
     RagQuery,
     LlmQuery,
     RetrievalQuery,
-    RagResponse,
     LlmResponse,
 )
 from pai_rag.app.web.view_model import view_model
diff --git a/src/pai_rag/modules/synthesizer/my_simple_synthesizer.py b/src/pai_rag/modules/synthesizer/my_simple_synthesizer.py
index 68de5f27..35886920 100644
--- a/src/pai_rag/modules/synthesizer/my_simple_synthesizer.py
+++ b/src/pai_rag/modules/synthesizer/my_simple_synthesizer.py
@@ -10,6 +10,8 @@
 from llama_index.core.service_context import ServiceContext
 from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
 from llama_index.core.types import RESPONSE_TEXT_TYPE
+
+
 class MySimpleSummarize(BaseSynthesizer):
     def __init__(
         self,
@@ -31,13 +33,16 @@ def __init__(
             streaming=streaming,
         )
         self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
+
     def _get_prompts(self) -> PromptDictType:
         """Get prompts."""
         return {"text_qa_template": self._text_qa_template}
+
     def _update_prompts(self, prompts: PromptDictType) -> None:
         """Update prompts."""
         if "text_qa_template" in prompts:
             self._text_qa_template = prompts["text_qa_template"]
+
     async def aget_response(
         self,
         query_str: str,
@@ -69,6 +74,7 @@ async def aget_response(
         else:
             response = cast(Generator, response)
         return response
+
     def get_response(
         self,
         query_str: str,
@@ -98,4 +104,4 @@ def get_response(
             response = response or "Empty Response"
         else:
             response = cast(Generator, response)
-        return response
\ No newline at end of file
+        return response
diff --git a/src/pai_rag/modules/synthesizer/synthesizer.py b/src/pai_rag/modules/synthesizer/synthesizer.py
index e3165196..aea39f03 100644
--- a/src/pai_rag/modules/synthesizer/synthesizer.py
+++ b/src/pai_rag/modules/synthesizer/synthesizer.py
@@ -21,7 +21,6 @@
     Refine,
     CompactAndRefine,
     TreeSummarize,
-    SimpleSummarize,
 )
 
 # from llama_index.core.response_synthesizers.type import ResponseMode
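
Reviewer note: below is a minimal sketch of how a caller might consume the streaming branch added to /service/query above. It is not part of the patch: the helper name `stream_query` is hypothetical, and the path `service/query` is an assumption inferred from the `service/data` pattern in rag_client.py (the patch never shows how `query_url` is built). The server's `event_generator` yields raw answer tokens through `StreamingResponse`, so each decoded chunk is a slice of the answer text rather than SSE-framed lines.

```python
# Hypothetical client sketch -- not part of the patch. Assumes the query
# endpoint is f"{endpoint}service/query"; adjust to match the real query_url.
import requests


def stream_query(endpoint: str, question: str, session_id: str = None):
    """POST a streaming query and yield answer tokens as they arrive."""
    r = requests.post(
        f"{endpoint}service/query",
        headers={"X-Session-ID": session_id},  # None simply omits the header
        json={"question": question, "stream": True},
        stream=True,  # keep the connection open for the event stream
    )
    r.raise_for_status()
    # The endpoint streams plain token text, so decode chunks as they land.
    yield from r.iter_content(chunk_size=None, decode_unicode=True)


if __name__ == "__main__":
    for token in stream_query("http://127.0.0.1:8000/", "What is PAI-RAG?"):
        print(token, end="", flush=True)
```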