diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx
index 01066fb60..9feed4652 100644
--- a/docs/source/package_reference/models.mdx
+++ b/docs/source/package_reference/models.mdx
@@ -21,7 +21,7 @@
 ## Endpoints-based Models
 ### InferenceEndpointModel
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
-[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
+[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel
 
 ### TGI ModelClient
diff --git a/examples/model_configs/endpoint_model_lite.yaml b/examples/model_configs/serverless_model.yaml
similarity index 100%
rename from examples/model_configs/endpoint_model_lite.yaml
rename to examples/model_configs/serverless_model.yaml
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index ccfccd0aa..07ce19c23 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -158,6 +158,13 @@ def inference_endpoint(
         str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
+    free_endpoint: Annotated[
+        bool,
+        Option(
+            help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
+            rich_help_panel=HELP_PANEL_NAME_4,
+        ),
+    ] = False,
     # === Common parameters ===
     use_chat_template: Annotated[
         bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
@@ -211,9 +218,7 @@ def inference_endpoint(
     Evaluate models using inference-endpoints as backend.
     """
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.endpoints.endpoint_model import (
-        InferenceEndpointModelConfig,
-    )
+    from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
     env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -231,10 +236,10 @@ def inference_endpoint(
     parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote
 
     # Find a way to add this back
-    # if config["base_params"].get("endpoint_name", None):
-    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
-    model_config = InferenceEndpointModelConfig.from_path(model_config_path)
+    if free_endpoint:
+        model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
+    else:
+        model_config = InferenceEndpointModelConfig.from_path(model_config_path)
 
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
index 2c9a0929a..47978adff 100644
--- a/src/lighteval/models/endpoints/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -77,8 +77,8 @@
 
 
 @dataclass
-class InferenceModelConfig:
-    model: str
+class ServerlessEndpointModelConfig:
+    model_name: str
     add_special_tokens: bool = True
     generation_parameters: GenerationParameters = None
 
@@ -86,6 +86,14 @@ def __post_init__(self):
         if not self.generation_parameters:
             self.generation_parameters = GenerationParameters()
 
+    @classmethod
+    def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        return cls(**config["base_params"])
+
 
 @dataclass
 class InferenceEndpointModelConfig:
@@ -161,7 +169,7 @@ class InferenceEndpointModel(LightevalModel):
     """
 
     def __init__(  # noqa: C901
-        self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
+        self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "reuse_existing", False)
         self._max_length = None
@@ -227,9 +235,7 @@ def __init__(  # noqa: C901
                         **config.get_dtype_args(),
                         **config.get_custom_env_vars(),
                     },
-                    "url": (
-                        config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
-                    ),
+                    "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:3.0.1"),
                 },
             )
         else:  # Endpoint exists
@@ -293,10 +299,10 @@ def __init__(  # noqa: C901
         else:  # Free inference client
             self.endpoint = None
             self.endpoint_name = None
-            self.name = config.model
+            self.name = config.model_name
             self.revision = "default"
-            self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
-            self.client = InferenceClient(model=config.model, token=env_config.token)
+            self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
+            self.client = InferenceClient(model=config.model_name, token=env_config.token)
 
         self.use_async = True  # set to False for debug - async use is faster
 
@@ -306,7 +312,7 @@ def __init__(  # noqa: C901
         self.model_info = ModelInfo(
             model_name=self.name,
             model_sha=self.revision,
-            model_dtype=config.model_dtype or "default",
+            model_dtype=getattr(config, "model_dtype", "default"),
            model_size=-1,
         )
         self.generation_parameters = config.generation_parameters
@@ -567,7 +573,12 @@ def loglikelihood(
                 cont_toks = torch.tensor(cur_request.tokenized_continuation)
                 len_choice = len(cont_toks)
 
-                logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
+                if self.endpoint:  # inference endpoint
+                    logits = [
+                        t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
+                    ]  # to check
+                else:  # serverless endpoint
+                    logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]
 
                 greedy_tokens = torch.tensor(logits).argmax(dim=-1)
                 max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index bd256472e..3a5542384 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -27,7 +27,7 @@
 from lighteval.models.endpoints.endpoint_model import (
     InferenceEndpointModel,
     InferenceEndpointModelConfig,
-    InferenceModelConfig,
+    ServerlessEndpointModelConfig,
 )
 from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
 from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -80,7 +80,7 @@ def load_model(  # noqa: C901
     if isinstance(config, TGIModelConfig):
         return load_model_with_tgi(config)
 
-    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
+    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
         return load_model_with_inference_endpoints(config, env_config=env_config)
 
     if isinstance(config, TransformersModelConfig):
diff --git a/tests/models/endpoints/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py
index 29fbb3c48..f4ba15d91 100644
--- a/tests/models/endpoints/test_endpoint_model.py
+++ b/tests/models/endpoints/test_endpoint_model.py
@@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
                 },
             ),
             (
-                "examples/model_configs/endpoint_model_lite.yaml",
+                "examples/model_configs/serverless_model.yaml",
                 {
                     "model_name": "meta-llama/Llama-3.1-8B-Instruct",
                     # Defaults:
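
For reference, a minimal usage sketch (not part of the patch) of the renamed config class: it assumes lighteval is installed from this branch and that the renamed examples/model_configs/serverless_model.yaml keeps a model.base_params.model_name entry, as the updated test fixture and the new from_path classmethod suggest.

    # Hypothetical sketch mirroring the new branch in main_endpoint.py; class names,
    # the YAML path, and from_path come from the diff, everything else is assumed.
    from lighteval.models.endpoints.endpoint_model import (
        InferenceEndpointModelConfig,
        ServerlessEndpointModelConfig,
    )

    free_endpoint = True  # mirrors the new free_endpoint option on inference_endpoint
    model_config_path = "examples/model_configs/serverless_model.yaml"

    if free_endpoint:
        # from_path reads yaml["model"]["base_params"] and forwards it as kwargs
        model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
    else:
        model_config = InferenceEndpointModelConfig.from_path(model_config_path)

    print(model_config.model_name)  # "meta-llama/Llama-3.1-8B-Instruct" per the test fixture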