support stream v1
wwxxzz committed Jun 4, 2024
1 parent d7854d4 commit 0adf250
Showing 7 changed files with 31 additions and 19 deletions.
14 changes: 6 additions & 8 deletions src/pai_rag/app/api/query.py
@@ -5,18 +5,11 @@
RagQuery,
LlmQuery,
RetrievalQuery,
RagResponse,
LlmResponse,
DataInput,
)
from fastapi.responses import StreamingResponse
from typing import Generator, AsyncGenerator
from fastapi import APIRouter, Security, Response
from fastapi.responses import StreamingResponse
from llama_index.core.base.llms.types import (
ChatResponse,
)
import json

router = APIRouter()


@@ -26,11 +19,13 @@ async def aquery(query: RagQuery):
if not query.stream:
return response
else:

async def event_generator():
full_response = ""
async for token in response.async_response_gen():
full_response = full_response + token
yield token

return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -43,16 +38,19 @@ async def aquery_llm(query: LlmQuery):
if not query.stream:
return response
else:

async def event_generator():
full_response = ""
async for token in response.async_response_gen():
full_response = full_response + token
yield token

return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)


@router.post("/query/retrieval")
async def aquery_retrieval(query: RetrievalQuery):
return await rag_service.aquery_retrieval(query)
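Both endpoints above follow the same streaming pattern: when `query.stream` is set, the answer is forwarded token by token through an async generator wrapped in FastAPI's `StreamingResponse` with the `text/event-stream` media type. A minimal, self-contained sketch of that pattern, where `fake_token_source` is an illustrative stand-in for the service's `response.async_response_gen()` and not part of the repository:

```python
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_token_source() -> AsyncGenerator[str, None]:
    # Illustrative stand-in for response.async_response_gen() in the diff above.
    for token in ["Hello", ", ", "world", "!"]:
        yield token


@app.post("/service/query")
async def aquery(stream: bool = False):
    if not stream:
        return {"answer": "Hello, world!"}

    async def event_generator():
        full_response = ""
        async for token in fake_token_source():
            full_response += token  # accumulated as in the patch, e.g. for history
            yield token

    return StreamingResponse(event_generator(), media_type="text/event-stream")
```

Served with `uvicorn`, a request with `stream=true` receives tokens as they are produced instead of one final JSON body.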
15 changes: 13 additions & 2 deletions src/pai_rag/app/web/rag_client.py
@@ -41,7 +41,12 @@ def config_url(self):
def load_data_url(self):
return f"{self.endpoint}service/data"

def query(self, text: str, session_id: str = None, stream: bool = False,):
def query(
self,
text: str,
session_id: str = None,
stream: bool = False,
):
q = dict(question=text, stream=stream)
r = requests.post(self.query_url, headers={"X-Session-ID": session_id}, json=q)
r.raise_for_status()
@@ -61,7 +66,13 @@ def query_llm(
eas_llm_top_k: float = 30,
stream: bool = False,
):
q = dict(question=text, topp=top_p, topk=eas_llm_top_k, temperature=temperature, stream=stream)
q = dict(
question=text,
topp=top_p,
topk=eas_llm_top_k,
temperature=temperature,
stream=stream,
)

r = requests.post(self.llm_url, headers={"X-Session-ID": session_id}, json=q)
r.raise_for_status()
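On the caller's side, `requests` can consume this event stream incrementally by passing `stream=True` to `requests.post` and iterating over the response body. A hedged sketch mirroring the `query` method above; `print_stream` and its signature are illustrative, not part of `rag_client.py`:

```python
import requests


def print_stream(url: str, question: str, session_id: str = None) -> str:
    """Illustrative helper: POST a streaming query and print tokens as they arrive."""
    payload = dict(question=question, stream=True)
    with requests.post(
        url, headers={"X-Session-ID": session_id}, json=payload, stream=True
    ) as r:
        r.raise_for_status()
        full_response = ""
        # chunk_size=None yields each chunk as soon as the server flushes it
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            full_response += chunk
            print(chunk, end="", flush=True)
        return full_response
```

Note that `requests` drops any header whose value is `None`, so calling the helper without a `session_id` is safe.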
10 changes: 5 additions & 5 deletions src/pai_rag/app/web/ui.py
@@ -126,17 +126,17 @@ def respond(input_elements: List[Any]):
msg = update_dict["question"]
chatbot = update_dict["chatbot"]
is_streaming = update_dict["is_streaming"]

if query_type == "LLM":
response = rag_client.query_llm(
text=msg,
session_id=current_session_id,
stream=is_streaming
text=msg, session_id=current_session_id, stream=is_streaming
)
elif query_type == "Retrieval":
response = rag_client.query_vector(msg)
else:
response = rag_client.query(text=msg, session_id=current_session_id,stream=is_streaming)
response = rag_client.query(
text=msg, session_id=current_session_id, stream=is_streaming
)
if update_dict["include_history"]:
current_session_id = response.session_id
else:
1 change: 0 additions & 1 deletion src/pai_rag/core/rag_application.py
@@ -7,7 +7,6 @@
RagQuery,
LlmQuery,
RetrievalQuery,
RagResponse,
LlmResponse,
ContextDoc,
RetrievalResponse,
1 change: 0 additions & 1 deletion src/pai_rag/core/rag_service.py
@@ -6,7 +6,6 @@
RagQuery,
LlmQuery,
RetrievalQuery,
RagResponse,
LlmResponse,
)
from pai_rag.app.web.view_model import view_model
8 changes: 7 additions & 1 deletion src/pai_rag/modules/synthesizer/my_simple_synthesizer.py
@@ -10,6 +10,8 @@
from llama_index.core.service_context import ServiceContext
from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
from llama_index.core.types import RESPONSE_TEXT_TYPE


class MySimpleSummarize(BaseSynthesizer):
def __init__(
self,
@@ -31,13 +33,16 @@ def __init__(
streaming=streaming,
)
self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL

def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {"text_qa_template": self._text_qa_template}

def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
if "text_qa_template" in prompts:
self._text_qa_template = prompts["text_qa_template"]

async def aget_response(
self,
query_str: str,
@@ -69,6 +74,7 @@ async def aget_response(
else:
response = cast(Generator, response)
return response

def get_response(
self,
query_str: str,
@@ -98,4 +104,4 @@ def get_response(
response = response or "Empty Response"
else:
response = cast(Generator, response)
return response
return response
1 change: 0 additions & 1 deletion src/pai_rag/modules/synthesizer/synthesizer.py
@@ -21,7 +21,6 @@
Refine,
CompactAndRefine,
TreeSummarize,
SimpleSummarize,
)

# from llama_index.core.response_synthesizers.type import ResponseMode
