diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py
index 61ab0b92..3e68c2de 100644
--- a/packages/ragbits-core/src/ragbits/core/llms/base.py
+++ b/packages/ragbits-core/src/ragbits/core/llms/base.py
@@ -1,6 +1,7 @@
 import enum
 import warnings as wrngs
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
 from functools import cached_property
 from typing import Generic, cast, overload
 
@@ -80,7 +81,6 @@ async def generate_raw(
             Raw text response from LLM.
         """
         options = (self.default_options | options) if options else self.default_options
-
         response = await self.client.call(
             conversation=self._format_chat_for_llm(prompt),
             options=options,
@@ -130,6 +130,32 @@
 
         return cast(OutputT, response)
 
+    async def generate_streaming(
+        self,
+        prompt: BasePrompt,
+        *,
+        options: LLMOptions | None = None,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Prepares and sends a prompt to the LLM and streams the results.
+
+        Args:
+            prompt: Formatted prompt template with conversation.
+            options: Options to use for the LLM client.
+
+        Returns:
+            Response stream from LLM.
+        """
+        options = (self.default_options | options) if options else self.default_options
+        response = await self.client.call_streaming(
+            conversation=self._format_chat_for_llm(prompt),
+            options=options,
+            json_mode=prompt.json_mode,
+            output_schema=prompt.output_schema(),
+        )
+        async for text_piece in response:
+            yield text_piece
+
     def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
         if prompt.list_images():
             wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}")
diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/base.py b/packages/ragbits-core/src/ragbits/core/llms/clients/base.py
index b22dec5a..b8df9dd9 100644
--- a/packages/ragbits-core/src/ragbits/core/llms/clients/base.py
+++ b/packages/ragbits-core/src/ragbits/core/llms/clients/base.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
 from dataclasses import asdict, dataclass
 from typing import Any, ClassVar, Generic, TypeVar
 
@@ -84,3 +85,24 @@ async def call(
         Returns:
             Response string from LLM.
         """
+
+    @abstractmethod
+    async def call_streaming(
+        self,
+        conversation: ChatFormat,
+        options: LLMClientOptions,
+        json_mode: bool = False,
+        output_schema: type[BaseModel] | dict | None = None,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Calls LLM inference API with output streaming.
+
+        Args:
+            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
+            options: Additional settings used by LLM.
+            json_mode: Force the response to be in JSON format.
+            output_schema: Schema for structured response (either Pydantic model or a JSON schema).
+
+        Returns:
+            Response stream from LLM.
+        """
diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py
index 92efe65c..d599e516 100644
--- a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py
+++ b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py
@@ -1,9 +1,11 @@
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
 
 from pydantic import BaseModel
 
 try:
     import litellm
+    from litellm.utils import CustomStreamWrapper, ModelResponse
 
     HAS_LITELLM = True
 except ImportError:
@@ -100,14 +102,7 @@ async def call(
             LLMStatusError: If the LLM API returns an error status code.
             LLMResponseError: If the LLM API response is invalid.
         """
-        supported_params = litellm.get_supported_openai_params(model=self.model_name)
-
-        response_format = None
-        if supported_params is not None and "response_format" in supported_params:
-            if output_schema is not None and self.use_structured_output:
-                response_format = output_schema
-            elif json_mode:
-                response_format = {"type": "json_object"}
+        response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode)
 
         with trace(
             messages=conversation,
@@ -117,22 +112,9 @@
             response_format=response_format,
             options=options.dict(),
         ) as outputs:
-            try:
-                response = await litellm.acompletion(
-                    messages=conversation,
-                    model=self.model_name,
-                    base_url=self.base_url,
-                    api_key=self.api_key,
-                    api_version=self.api_version,
-                    response_format=response_format,
-                    **options.dict(),
-                )
-            except litellm.openai.APIConnectionError as exc:
-                raise LLMConnectionError() from exc
-            except litellm.openai.APIStatusError as exc:
-                raise LLMStatusError(exc.message, exc.status_code) from exc
-            except litellm.openai.APIResponseValidationError as exc:
-                raise LLMResponseError() from exc
+            response = await self._get_litellm_response(
+                conversation=conversation, options=options, response_format=response_format
+            )
 
             if not response.choices:  # type: ignore
                 raise LLMEmptyResponseError()
@@ -144,3 +126,90 @@
             outputs.total_tokens = response.usage.total_tokens  # type: ignore
 
         return outputs.response  # type: ignore
+
+    async def call_streaming(
+        self,
+        conversation: ChatFormat,
+        options: LiteLLMOptions,
+        json_mode: bool = False,
+        output_schema: type[BaseModel] | dict | None = None,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Calls the appropriate LLM endpoint with the given prompt and options, streaming the response.
+
+        Args:
+            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
+            options: Additional settings used by the LLM.
+            json_mode: Force the response to be in JSON format.
+            output_schema: Output schema for requesting a specific response format.
+                Only used if the client has been initialized with `use_structured_output=True`.
+
+        Returns:
+            Response stream from LLM.
+
+        Raises:
+            LLMConnectionError: If there is a connection error with the LLM API.
+            LLMStatusError: If the LLM API returns an error status code.
+            LLMResponseError: If the LLM API response is invalid.
+        """
+        response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode)
+        with trace(
+            messages=conversation,
+            model=self.model_name,
+            base_url=self.base_url,
+            api_version=self.api_version,
+            response_format=response_format,
+            options=options.dict(),
+        ) as outputs:
+            response = await self._get_litellm_response(
+                conversation=conversation, options=options, response_format=response_format, stream=True
+            )
+
+            if not response.completion_stream:  # type: ignore
+                raise LLMEmptyResponseError()
+
+            async def response_to_async_generator(response: CustomStreamWrapper) -> AsyncGenerator[str, None]:
+                async for item in response:
+                    yield item.choices[0].delta.content or ""
+
+            outputs.response = response_to_async_generator(response)  # type: ignore
+            return outputs.response  # type: ignore
+
+    async def _get_litellm_response(
+        self,
+        conversation: ChatFormat,
+        options: LiteLLMOptions,
+        response_format: type[BaseModel] | dict | None,
+        stream: bool = False,
+    ) -> ModelResponse | CustomStreamWrapper:
+        try:
+            response = await litellm.acompletion(
+                messages=conversation,
+                model=self.model_name,
+                base_url=self.base_url,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                response_format=response_format,
+                stream=stream,
+                **options.dict(),
+            )
+        except litellm.openai.APIConnectionError as exc:
+            raise LLMConnectionError() from exc
+        except litellm.openai.APIStatusError as exc:
+            raise LLMStatusError(exc.message, exc.status_code) from exc
+        except litellm.openai.APIResponseValidationError as exc:
+            raise LLMResponseError() from exc
+        return response
+
+    def _get_response_format(
+        self, output_schema: type[BaseModel] | dict | None, json_mode: bool
+    ) -> type[BaseModel] | dict | None:
+        supported_params = litellm.get_supported_openai_params(model=self.model_name)
+
+        response_format = None
+        if supported_params is not None and "response_format" in supported_params:
+            if output_schema is not None and self.use_structured_output:
+                response_format = output_schema
+            elif json_mode:
+                response_format = {"type": "json_object"}
+        return response_format
diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/local.py b/packages/ragbits-core/src/ragbits/core/llms/clients/local.py
index 28f0987e..0d8ddee3 100644
--- a/packages/ragbits-core/src/ragbits/core/llms/clients/local.py
+++ b/packages/ragbits-core/src/ragbits/core/llms/clients/local.py
@@ -1,10 +1,14 @@
+import asyncio
+import threading
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
 
 from pydantic import BaseModel
 
 try:
+    import accelerate  # noqa: F401
     import torch
-    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
     HAS_LOCAL_LLM = True
 except ImportError:
@@ -99,3 +103,40 @@ async def call(
         response = outputs[0][input_ids.shape[-1] :]
         decoded_response = self.tokenizer.decode(response, skip_special_tokens=True)
         return decoded_response
+
+    async def call_streaming(
+        self,
+        conversation: ChatFormat,
+        options: LocalLLMOptions,
+        json_mode: bool = False,
+        output_schema: type[BaseModel] | dict | None = None,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Makes a call to the local LLM with the provided prompt and options in a streaming manner.
+
+        Args:
+            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
+            options: Additional settings used by the LLM.
+            json_mode: Force the response to be in JSON format (not used).
+            output_schema: Output schema for requesting a specific response format (not used).
+
+        Returns:
+            Async generator of tokens
+        """
+        input_ids = self.tokenizer.apply_chat_template(
+            conversation, add_generation_prompt=True, return_tensors="pt"
+        ).to(self.model.device)
+        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
+        generation_kwargs = dict(streamer=streamer, **options.dict())
+        generation_thread = threading.Thread(target=self.model.generate, args=(input_ids,), kwargs=generation_kwargs)
+
+        async def streamer_to_async_generator(
+            streamer: TextIteratorStreamer, generation_thread: threading.Thread
+        ) -> AsyncGenerator[str, None]:
+            generation_thread.start()
+            for text_piece in streamer:
+                yield text_piece
+                await asyncio.sleep(0.0)
+            generation_thread.join()
+
+        return streamer_to_async_generator(streamer=streamer, generation_thread=generation_thread)
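
Usage sketch (not part of the diff above): the snippet below shows how the new generate_streaming API introduced in llms/base.py could be consumed end to end. It is a minimal sketch under assumptions: the JokeInput/JokePrompt classes and the "gpt-4o-mini" model name are hypothetical placeholders, and the ragbits.core.prompt.Prompt and ragbits.core.llms.litellm.LiteLLM import paths are assumed rather than taken from this diff.

# Usage sketch only; adjust imports, prompt class, and model name to your setup.
import asyncio

from pydantic import BaseModel

from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.prompt import Prompt


class JokeInput(BaseModel):
    topic: str


class JokePrompt(Prompt[JokeInput, str]):  # hypothetical prompt for illustration
    system_prompt = "You are a stand-up comedian."
    user_prompt = "Tell a short joke about {{ topic }}."


async def main() -> None:
    llm = LiteLLM("gpt-4o-mini")  # hypothetical model name; any litellm-supported model should work
    prompt = JokePrompt(JokeInput(topic="asynchronous Python"))

    # generate_streaming() is an async generator: text pieces arrive as the model
    # produces them instead of waiting for the full completion.
    async for text_piece in llm.generate_streaming(prompt):
        print(text_piece, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())

Whichever backend is used, the client-level call_streaming implementations above expose the same AsyncGenerator[str, None] surface: the LiteLLM client adapts the deltas of litellm's CustomStreamWrapper, while the local client bridges transformers' TextIteratorStreamer from a generation thread, so generate_streaming behaves the same from the caller's point of view.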