Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api, backend): load and expose backend model info at runtime #890

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions src/leapfrogai_api/backend/rag/leapfrogai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import leapfrogai_sdk as lfai
from leapfrogai_api.utils import get_model_config
from leapfrogai_api.utils.__init__ import config as global_config
from leapfrogai_api.backend.grpc_client import create_embeddings
import logging

Expand Down Expand Up @@ -59,8 +59,8 @@ async def _get_model(
Raises:
ValueError: If the embeddings model is not found.
"""

if not (model := get_model_config().get_model_backend(model=model_name)):
config = await global_config.create()
if not (model := config.get_model_backend(model=model_name)):
logging.error(f"Embeddings model {model_name} not found.")
raise ValueError("Embeddings model not found.")

Expand Down
172 changes: 164 additions & 8 deletions src/leapfrogai_api/backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import datetime
from enum import Enum
from typing import Literal
from typing import Any, Literal
import warnings

from fastapi import UploadFile, Form, File
from openai.types import FileObject
Expand All @@ -18,7 +19,7 @@
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.beta.vector_store import ExpiresAfter
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator, ValidationInfo, ConfigDict

##########
# DEFAULTS
Expand All @@ -27,7 +28,7 @@

DEFAULT_MAX_COMPLETION_TOKENS = 4096
DEFAULT_MAX_PROMPT_TOKENS = 4096

DEFAULT_DIMENSION: int = 2000

##########
# GENERIC
Expand All @@ -49,11 +50,160 @@ class Usage(BaseModel):
)


##########
# WARNINGS
##########


class LeapfrogAIWarning(UserWarning):
"""Base warning class for LeapfrogAI."""


class DimensionWarning(LeapfrogAIWarning):
"""Warning for dimension-related issues."""

def __init__(
self,
dimension: int,
dimension_default: int = DEFAULT_DIMENSION,
):
super().__init__(
f"Dimension {dimension} exceeds the recommended maximum of {dimension_default}."
)


class CapabilityWarning(LeapfrogAIWarning):
"""Warning for capability-related issues."""


##########
# MODELS
##########


class Modality(str, Enum):
"""Defines the input/output modality of a model (image, text, speech, or null)."""

IMAGE = "image"
TEXT = "text"
SPEECH = "speech"


class Capability(str, Enum):
"""Specifies the functional capabilities of a model (chat, embeddings, speech-to-text, text-to-speech, or null)."""

CHAT = "chat"
EMBEDDINGS = "embeddings"
STT = "stt"
TTS = "tts"


class Precision(str, Enum):
"""Indicates the numerical precision used in the model's parameters (float16, float32, bfloat16, int8, int4, or null)."""

FLOAT16 = "float16"
FLOAT32 = "float32"
BFLOAT16 = "bfloat16"
INT8 = "int8"
INT4 = "int4"


class Format(str, Enum):
"""Describes the storage or quantization format of the model (None, GPTQ, GGUF, SqueezeLLM, AWQ, or null)."""

AWQ = "AWQ"
GGUF = "GGUF"
GPTQ = "GPTQ"
SQUEEZELLM = "SqueezeLLM"


class ModelMetadataResponse(BaseModel):
"""Metadata for the model, including type, dimensions (for embeddings), and precision."""

model_config = ConfigDict(use_enum_values=True)

capabilities: list[Capability] | None = Field(
default=None, # TODO: should we define this as an empty string if it's None?
description="Model capabilities (e.g., 'embeddings', 'chat', 'tts', 'stt')",
)
dimensions: int | None = Field(
jamestexas marked this conversation as resolved.
Show resolved Hide resolved
default=None,
description="Embedding dimensions (for embeddings models)",
)
format: Format | None = Field(
default=None,
description="Model format (e.g., None, 'GPTQ', 'GGUF', 'SqueezeLLM', 'AWQ')",
)
modalities: list[Modality] | None = Field(
default=None, # TODO: should we define this as an empty string if it's None?
description="The modalities of the model (e.g., 'image', 'text', 'speech')",
)

precision: Precision | None = Field(
default=None,
description="Model precision (e.g., 'float16', 'float32', 'bfloat16', 'int8', 'int4')",
)
type: Literal["embeddings", "llm"] | None = Field(
default=None,
description="The type of the model e.g. ('embeddings' or 'llm')",
)

@field_validator("dimensions")
@classmethod
def check_dimensions(
cls,
v: int | None,
info: ValidationInfo,
) -> int | None:
"""
Validates the 'dimensions' field of a model's metadata.

Args:
v: The dimension value to be validated.
info: The validation information.

Returns:
The validated dimension value.

Raises:
CapabilityError: If the 'dimensions' field is not set for models with 'embeddings' capability or vice versa.
"""
if v is not None and v > 2000:
warnings.warn(DimensionWarning(dimension=v))
return v

@field_validator("capabilities")
@classmethod
def validate_capabilities(
cls,
v: list[Capability] | None,
values: dict[str, Any],
) -> list[Capability] | None:
"""
Validates the 'capabilities' field of a model's metadata, ensuring that 'dimensions'
is correctly set when 'embeddings' is in capabilities.
"""
# TODO: Actually error here when 'embeddings' is not in capabilities, once this is actually implemented

# Check if dimensions is set, warn if 'embeddings' is not in capabilities
if (dimensions_set := (values.get("dimensions", None) is not None)) and not (
embeddings_set := (v is not None and "embeddings" in v)
):
warnings.warn(
CapabilityWarning(
"'dimensions' should only be set for models with 'embeddings' capability"
)
)
# Check if dimensions is not set, warn if 'embeddings' is in capabilities
elif not dimensions_set and embeddings_set:
warnings.warn(
CapabilityWarning(
"'dimensions' must be set for models with 'embeddings' capability"
)
)
return v


class ModelResponseModel(BaseModel):
"""Represents a single model in the response."""

Expand All @@ -75,6 +225,10 @@ class ModelResponseModel(BaseModel):
default="leapfrogai",
description="The organization that owns the model. Always 'leapfrogai' for LeapfrogAI models.",
)
metadata: ModelMetadataResponse | None = Field(
default=None,
description="Metadata for the model, including type, dimensions (for embeddings), and precision.",
)


class ModelResponse(BaseModel):
Expand All @@ -100,7 +254,7 @@ class CompletionRequest(BaseModel):
model: str = Field(
...,
description="The ID of the model to use for completion.",
example="llama-cpp-python",
examples=["llama-cpp-python"],
)
prompt: str | list[int] = Field(
...,
Expand Down Expand Up @@ -132,7 +286,9 @@ class CompletionChoice(BaseModel):
description="Log probabilities for the generated tokens. Only returned if requested.",
)
finish_reason: str = Field(
"", description="The reason why the completion finished.", example="length"
"",
description="The reason why the completion finished.",
examples=["length"],
)


Expand Down Expand Up @@ -579,12 +735,12 @@ class CreateVectorStoreRequest(BaseModel):
file_ids: list[str] | None = Field(
default=[],
description="List of file IDs to be included in the vector store.",
example=["file-abc123", "file-def456"],
examples=["file-abc123", "file-def456"],
)
name: str | None = Field(
default=None,
description="Optional name for the vector store.",
example="My Vector Store",
examples=["My Vector Store"],
)
expires_after: ExpiresAfter | None = Field(
default=None,
Expand All @@ -594,7 +750,7 @@ class CreateVectorStoreRequest(BaseModel):
metadata: dict | None = Field(
default=None,
description="Optional metadata for the vector store.",
example={"project": "AI Research", "version": "1.0"},
examples=[{"project": "AI Research", "version": "1.0"}],
)

def add_days_to_timestamp(self, timestamp: int, days: int) -> int:
Expand Down
Loading
Loading