Skip to content

Commit

Permalink
Merge branch 'main' into clem_homogeneize_generation_params
Browse files: browse the repository at this point in the history
  • Loading branch information
clefourrier authored Dec 18, 2024
2 parents 6a18b81 + 51ca581 commit 83cbb10
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/source/package_reference/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
## Endpoints-based Models
### InferenceEndpointModel
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

### TGI ModelClient
Expand Down
19 changes: 12 additions & 7 deletions src/lighteval/main_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ def inference_endpoint(
str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
free_endpoint: Annotated[
bool,
Option(
help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
rich_help_panel=HELP_PANEL_NAME_4,
),
] = False,
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
Expand Down Expand Up @@ -211,9 +218,7 @@ def inference_endpoint(
Evaluate models using inference-endpoints as backend.
"""
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModelConfig,
)
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
Expand All @@ -231,10 +236,10 @@ def inference_endpoint(
parallelism_manager = ParallelismManager.NONE # since we're using inference endpoints in remote

# Find a way to add this back
# if config["base_params"].get("endpoint_name", None):
# return InferenceModelConfig(model=config["base_params"]["endpoint_name"])

model_config = InferenceEndpointModelConfig.from_path(model_config_path)
if free_endpoint:
model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
else:
model_config = InferenceEndpointModelConfig.from_path(model_config_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
Expand Down
33 changes: 22 additions & 11 deletions src/lighteval/models/endpoints/endpoint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,23 @@


@dataclass
class InferenceModelConfig:
model: str
class ServerlessEndpointModelConfig:
model_name: str
add_special_tokens: bool = True
generation_parameters: GenerationParameters = None

def __post_init__(self):
    """Ensure ``generation_parameters`` is populated after dataclass initialisation."""
    # The field defaults to None; any falsy value is swapped for a default
    # GenerationParameters instance, while a supplied value is kept as-is.
    self.generation_parameters = self.generation_parameters or GenerationParameters()

@classmethod
def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
    """Build a serverless endpoint configuration from a YAML file.

    The loaded document is expected to expose a top-level ``model`` entry
    holding a ``base_params`` mapping; every key of ``base_params`` is
    forwarded as a keyword argument to the class constructor.

    Args:
        path: Path to the YAML model configuration file.

    Returns:
        A configuration instance populated from ``model.base_params``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If ``model`` or ``base_params`` is missing from the loaded mapping.
        TypeError: If ``base_params`` contains a key the constructor does not accept.
    """
    # Local import: yaml is only required when this method is called.
    import yaml

    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)["model"]
    return cls(**config["base_params"])


@dataclass
class InferenceEndpointModelConfig:
Expand Down Expand Up @@ -161,7 +169,7 @@ class InferenceEndpointModel(LightevalModel):
"""

def __init__( # noqa: C901
self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
) -> None:
self.reuse_existing = getattr(config, "reuse_existing", False)
self._max_length = None
Expand Down Expand Up @@ -227,9 +235,7 @@ def __init__( # noqa: C901
**config.get_dtype_args(),
**config.get_custom_env_vars(),
},
"url": (
config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
),
"url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:3.0.1"),
},
)
else: # Endpoint exists
Expand Down Expand Up @@ -293,10 +299,10 @@ def __init__( # noqa: C901
else: # Free inference client
self.endpoint = None
self.endpoint_name = None
self.name = config.model
self.name = config.model_name
self.revision = "default"
self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
self.client = InferenceClient(model=config.model, token=env_config.token)
self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
self.client = InferenceClient(model=config.model_name, token=env_config.token)

self.use_async = True # set to False for debug - async use is faster

Expand All @@ -306,7 +312,7 @@ def __init__( # noqa: C901
self.model_info = ModelInfo(
model_name=self.name,
model_sha=self.revision,
model_dtype=config.model_dtype or "default",
model_dtype=getattr(config, "model_dtype", "default"),
model_size=-1,
)
self.generation_parameters = config.generation_parameters
Expand Down Expand Up @@ -567,7 +573,12 @@ def loglikelihood(
cont_toks = torch.tensor(cur_request.tokenized_continuation)
len_choice = len(cont_toks)

logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
if self.endpoint: # inference endpoint
logits = [
t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
] # to check
else: # serverless endpoint
logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

greedy_tokens = torch.tensor(logits).argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/models/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModel,
InferenceEndpointModelConfig,
InferenceModelConfig,
ServerlessEndpointModelConfig,
)
from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
Expand Down Expand Up @@ -80,7 +80,7 @@ def load_model( # noqa: C901
if isinstance(config, TGIModelConfig):
return load_model_with_tgi(config)

if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
return load_model_with_inference_endpoints(config, env_config=env_config)

if isinstance(config, TransformersModelConfig):
Expand Down
2 changes: 1 addition & 1 deletion tests/models/endpoints/test_endpoint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
},
),
(
"examples/model_configs/endpoint_model_lite.yaml",
"examples/model_configs/serverless_model.yaml",
{
"model_name": "meta-llama/Llama-3.1-8B-Instruct",
# Defaults:
Expand Down

0 comments on commit 83cbb10

Please sign in to comment.