feat(core): streaming API for LLMs (#188)
kdziedzic68 authored Nov 19, 2024
1 parent c453926 commit 1947000
Showing 4 changed files with 184 additions and 26 deletions.
28 changes: 27 additions & 1 deletion packages/ragbits-core/src/ragbits/core/llms/base.py
@@ -1,6 +1,7 @@
import enum
import warnings as wrngs
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Generic, cast, overload

@@ -80,7 +81,6 @@ async def generate_raw(
            Raw text response from LLM.
        """
        options = (self.default_options | options) if options else self.default_options

        response = await self.client.call(
            conversation=self._format_chat_for_llm(prompt),
            options=options,
@@ -130,6 +130,32 @@ async def generate(

        return cast(OutputT, response)

    async def generate_streaming(
        self,
        prompt: BasePrompt,
        *,
        options: LLMOptions | None = None,
    ) -> AsyncGenerator[str, None]:
        """
        Prepares and sends a prompt to the LLM and streams the results.

        Args:
            prompt: Formatted prompt template with conversation.
            options: Options to use for the LLM client.

        Returns:
            Response stream from LLM.
        """
        options = (self.default_options | options) if options else self.default_options
        response = await self.client.call_streaming(
            conversation=self._format_chat_for_llm(prompt),
            options=options,
            json_mode=prompt.json_mode,
            output_schema=prompt.output_schema(),
        )
        async for text_piece in response:
            yield text_piece

    def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
        if prompt.list_images():
            wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}")
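
For context, the new method is consumed with a plain async for loop instead of awaiting generate. The sketch below is illustrative only and not part of this diff; the import paths, the prompt subclass, and the model name are assumptions.

import asyncio

from ragbits.core.llms.litellm import LiteLLM  # assumed import path, not part of this diff
from ragbits.core.prompt import Prompt         # assumed import path, not part of this diff


class JokePrompt(Prompt):  # hypothetical prompt, for illustration only
    user_prompt = "Tell me a short joke about streaming APIs."


async def main() -> None:
    llm = LiteLLM("gpt-4o-mini")  # model name is a placeholder
    async for text_piece in llm.generate_streaming(JokePrompt()):
        print(text_piece, end="", flush=True)


asyncio.run(main())
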
22 changes: 22 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/clients/base.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import asdict, dataclass
from typing import Any, ClassVar, Generic, TypeVar

@@ -84,3 +85,24 @@ async def call(
        Returns:
            Response string from LLM.
        """

    @abstractmethod
    async def call_streaming(
        self,
        conversation: ChatFormat,
        options: LLMClientOptions,
        json_mode: bool = False,
        output_schema: type[BaseModel] | dict | None = None,
    ) -> AsyncGenerator[str, None]:
        """
        Calls LLM inference API with output streaming.

        Args:
            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
            options: Additional settings used by LLM.
            json_mode: Force the response to be in JSON format.
            output_schema: Schema for structured response (either Pydantic model or a JSON schema).

        Returns:
            Response stream from LLM.
        """
117 changes: 93 additions & 24 deletions packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py
@@ -1,9 +1,11 @@
from collections.abc import AsyncGenerator
from dataclasses import dataclass

from pydantic import BaseModel

try:
    import litellm
    from litellm.utils import CustomStreamWrapper, ModelResponse

    HAS_LITELLM = True
except ImportError:
@@ -100,14 +102,7 @@ async def call(
            LLMStatusError: If the LLM API returns an error status code.
            LLMResponseError: If the LLM API response is invalid.
        """
        supported_params = litellm.get_supported_openai_params(model=self.model_name)

        response_format = None
        if supported_params is not None and "response_format" in supported_params:
            if output_schema is not None and self.use_structured_output:
                response_format = output_schema
            elif json_mode:
                response_format = {"type": "json_object"}
        response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode)

        with trace(
            messages=conversation,
@@ -117,22 +112,9 @@
            response_format=response_format,
            options=options.dict(),
        ) as outputs:
            try:
                response = await litellm.acompletion(
                    messages=conversation,
                    model=self.model_name,
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                    response_format=response_format,
                    **options.dict(),
                )
            except litellm.openai.APIConnectionError as exc:
                raise LLMConnectionError() from exc
            except litellm.openai.APIStatusError as exc:
                raise LLMStatusError(exc.message, exc.status_code) from exc
            except litellm.openai.APIResponseValidationError as exc:
                raise LLMResponseError() from exc
            response = await self._get_litellm_response(
                conversation=conversation, options=options, response_format=response_format
            )

            if not response.choices:  # type: ignore
                raise LLMEmptyResponseError()
@@ -144,3 +126,90 @@
            outputs.total_tokens = response.usage.total_tokens  # type: ignore

        return outputs.response  # type: ignore

    async def call_streaming(
        self,
        conversation: ChatFormat,
        options: LiteLLMOptions,
        json_mode: bool = False,
        output_schema: type[BaseModel] | dict | None = None,
    ) -> AsyncGenerator[str, None]:
        """
        Calls the appropriate LLM endpoint with the given conversation and options, streaming the response.

        Args:
            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
            options: Additional settings used by the LLM.
            json_mode: Force the response to be in JSON format.
            output_schema: Output schema for requesting a specific response format.
                Only used if the client has been initialized with `use_structured_output=True`.

        Returns:
            Response stream from LLM.

        Raises:
            LLMConnectionError: If there is a connection error with the LLM API.
            LLMStatusError: If the LLM API returns an error status code.
            LLMResponseError: If the LLM API response is invalid.
        """
        response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode)
        with trace(
            messages=conversation,
            model=self.model_name,
            base_url=self.base_url,
            api_version=self.api_version,
            response_format=response_format,
            options=options.dict(),
        ) as outputs:
            response = await self._get_litellm_response(
                conversation=conversation, options=options, response_format=response_format, stream=True
            )

            if not response.completion_stream:  # type: ignore
                raise LLMEmptyResponseError()

            async def response_to_async_generator(response: CustomStreamWrapper) -> AsyncGenerator[str, None]:
                async for item in response:
                    yield item.choices[0].delta.content or ""

            outputs.response = response_to_async_generator(response)  # type: ignore
        return outputs.response  # type: ignore

    async def _get_litellm_response(
        self,
        conversation: ChatFormat,
        options: LiteLLMOptions,
        response_format: type[BaseModel] | dict | None,
        stream: bool = False,
    ) -> ModelResponse | CustomStreamWrapper:
        try:
            response = await litellm.acompletion(
                messages=conversation,
                model=self.model_name,
                base_url=self.base_url,
                api_key=self.api_key,
                api_version=self.api_version,
                response_format=response_format,
                stream=stream,
                **options.dict(),
            )
        except litellm.openai.APIConnectionError as exc:
            raise LLMConnectionError() from exc
        except litellm.openai.APIStatusError as exc:
            raise LLMStatusError(exc.message, exc.status_code) from exc
        except litellm.openai.APIResponseValidationError as exc:
            raise LLMResponseError() from exc
        return response

    def _get_response_format(
        self, output_schema: type[BaseModel] | dict | None, json_mode: bool
    ) -> type[BaseModel] | dict | None:
        supported_params = litellm.get_supported_openai_params(model=self.model_name)

        response_format = None
        if supported_params is not None and "response_format" in supported_params:
            if output_schema is not None and self.use_structured_output:
                response_format = output_schema
            elif json_mode:
                response_format = {"type": "json_object"}
        return response_format
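
For reference, the streaming path above boils down to litellm's stream=True mode plus unwrapping each chunk's delta, exactly as response_to_async_generator does. A standalone sketch of that pattern (the model name and prompt are placeholders, not part of this change):

import asyncio

import litellm


async def stream_completion() -> None:
    # Request a streamed completion and print each delta as it arrives.
    response = await litellm.acompletion(
        messages=[{"role": "user", "content": "Explain streaming in one sentence."}],
        model="gpt-4o-mini",  # placeholder model name
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)


asyncio.run(stream_completion())
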
43 changes: 42 additions & 1 deletion packages/ragbits-core/src/ragbits/core/llms/clients/local.py
@@ -1,10 +1,14 @@
import asyncio
import threading
from collections.abc import AsyncGenerator
from dataclasses import dataclass

from pydantic import BaseModel

try:
    import accelerate  # noqa: F401
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    HAS_LOCAL_LLM = True
except ImportError:
@@ -99,3 +103,40 @@ async def call(
        response = outputs[0][input_ids.shape[-1] :]
        decoded_response = self.tokenizer.decode(response, skip_special_tokens=True)
        return decoded_response

    async def call_streaming(
        self,
        conversation: ChatFormat,
        options: LocalLLMOptions,
        json_mode: bool = False,
        output_schema: type[BaseModel] | dict | None = None,
    ) -> AsyncGenerator[str, None]:
        """
        Makes a call to the local LLM with the provided prompt and options in a streaming manner.

        Args:
            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
            options: Additional settings used by the LLM.
            json_mode: Force the response to be in JSON format (not used).
            output_schema: Output schema for requesting a specific response format (not used).

        Returns:
            Async generator of text pieces produced by the model.
        """
        input_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(streamer=streamer, **options.dict())
        generation_thread = threading.Thread(target=self.model.generate, args=(input_ids,), kwargs=generation_kwargs)

        async def streamer_to_async_generator(
            streamer: TextIteratorStreamer, generation_thread: threading.Thread
        ) -> AsyncGenerator[str, None]:
            generation_thread.start()
            for text_piece in streamer:
                yield text_piece
                await asyncio.sleep(0.0)
            generation_thread.join()

        return streamer_to_async_generator(streamer=streamer, generation_thread=generation_thread)
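
The local client builds on transformers' TextIteratorStreamer: generate() blocks until completion, so it runs in a worker thread while the caller drains the streamer, and the await asyncio.sleep(0.0) above hands control back to the event loop between pieces. A minimal synchronous sketch of that underlying pattern (the gpt2 model and generation settings are placeholders for illustration, not part of this change):

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder model; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Streaming text generation means", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until done, so it runs in a thread while we iterate the streamer.
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32},
)
thread.start()
for text_piece in streamer:
    print(text_piece, end="", flush=True)
thread.join()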
