Skip to content

Commit

Permalink
LLM: improve documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
liffiton committed Dec 14, 2024
1 parent 87cdb72 commit e0b3889
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
22 changes: 18 additions & 4 deletions src/gened/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@


def _get_client(provider: LLMProvider, model: str, api_key: str) -> OpenAIClient:
""" Return a configured OpenAI client object (using OpenAI-compatible
endpoints for other providers) """
"""Create and configure an OpenAI-compatible client for the given provider.
Args:
provider: The LLM provider to use
model: The model identifier
api_key: The API key for authentication
Returns:
A configured OpenAIClient instance using the appropriate base URL for the provider
"""
match provider:
case 'google':
# https://ai.google.dev/gemini-api/docs/openai
Expand All @@ -31,14 +39,20 @@ def _get_client(provider: LLMProvider, model: str, api_key: str) -> OpenAIClient

@dataclass
class LLM:
    """Wrap a provider/model/key triple with deferred client creation.

    Tracks the current user's remaining token budget and builds the
    underlying OpenAI-compatible client only when it is first needed.
    """
    provider: LLMProvider
    model: str
    api_key: str
    # None when the current user is not subject to token limits.
    tokens_remaining: int | None = None
    # Created lazily on first completion request; excluded from __init__ and repr.
    _client: OpenAIClient | None = field(default=None, init=False, repr=False)

    async def get_completion(self, prompt: str | None = None, messages: list[OpenAIChatMessage] | None = None) -> tuple[dict[str, str], str]:
        """Request a completion from the language model.

        The client object is instantiated on the first call and reused
        afterwards; the actual request is delegated to
        OpenAIClient.get_completion() (see openai_client.py).
        """
        client = self._client
        if client is None:
            client = _get_client(self.provider, self.model, self.api_key)
            self._client = client
        return await client.get_completion(prompt, messages)
Expand Down Expand Up @@ -197,7 +211,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:


def get_models() -> list[Row]:
    """Return every active model row from the database, ordered by id."""
    return get_db().execute(
        "SELECT * FROM models WHERE active ORDER BY id ASC"
    ).fetchall()
27 changes: 23 additions & 4 deletions src/gened/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,39 @@


class OpenAIClient:
"""Client for interacting with OpenAI or compatible API endpoints."""

def __init__(self, model: str, api_key: str, *, base_url: str | None = None):
    """Set up the underlying async OpenAI client.

    Args:
        model: The model identifier to use for completions
        api_key: The API key for authentication
        base_url: Optional base URL for non-OpenAI providers
            (omitted entirely when not given, so the library default is used)
    """
    client_kwargs = {"api_key": api_key}
    if base_url:
        client_kwargs["base_url"] = base_url
    self._client = openai.AsyncOpenAI(**client_kwargs)
    self._model = model

async def get_completion(self, prompt: str | None = None, messages: list[OpenAIChatMessage] | None = None) -> tuple[dict[str, str], str]:
'''
"""Get a completion from the LLM.
Args:
prompt: A single prompt string (converted to a message if provided)
messages: A list of chat messages in OpenAI format
(Only one of prompt or messages should be provided)
Returns:
- A tuple containing:
- An OpenAI response object
A tuple containing:
- The raw API response as a dict
- The response text (stripped)
'''
Note:
If an error occurs, the dict will contain an 'error' key with the error details,
and the text will contain a user-friendly error message.
"""
common_error_text = "Error ({error_type}). Something went wrong with this query. The error has been logged, and we'll work on it. For now, please try again."
try:
if messages is None:
Expand Down

0 comments on commit e0b3889

Please sign in to comment.