diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx
index 096ce7be..dcf5bc8d 100644
--- a/docs/source/package_reference/models.mdx
+++ b/docs/source/package_reference/models.mdx
@@ -21,7 +21,7 @@
 ## Endpoints-based Models
 ### InferenceEndpointModel
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
-[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
+[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

 ### TGI ModelClient
diff --git a/examples/model_configs/endpoint_model_lite.yaml b/examples/model_configs/serverless_model.yaml
similarity index 100%
rename from examples/model_configs/endpoint_model_lite.yaml
rename to examples/model_configs/serverless_model.yaml
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index 0b291f59..f992d65c 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -146,6 +146,13 @@ def inference_endpoint(
         str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
+    free_endpoint: Annotated[
+        bool,
+        Option(
+            help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
+            rich_help_panel=HELP_PANEL_NAME_4,
+        ),
+    ] = False,
     # === Common parameters ===
     use_chat_template: Annotated[
         bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
@@ -200,9 +207,7 @@ def inference_endpoint(
     """
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.endpoints.endpoint_model import (
-        InferenceEndpointModelConfig,
-    )
+    from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

     env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)

@@ -220,10 +225,10 @@ def inference_endpoint(
     parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote

     # Find a way to add this back
-    # if config["base_params"].get("endpoint_name", None):
-    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
-    model_config = InferenceEndpointModelConfig.from_path(model_config_path)
+    if free_endpoint:
+        model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
+    else:
+        model_config = InferenceEndpointModelConfig.from_path(model_config_path)

     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
index e50c0405..80798b61 100644
--- a/src/lighteval/models/endpoints/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -75,10 +75,18 @@


 @dataclass
-class InferenceModelConfig:
-    model: str
+class ServerlessEndpointModelConfig:
+    model_name: str
     add_special_tokens: bool = True

+    @classmethod
+    def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        return cls(**config["base_params"])
+

 @dataclass
 class InferenceEndpointModelConfig:
@@ -150,7 +158,7 @@ class InferenceEndpointModel(LightevalModel):
     """

     def __init__(  # noqa: C901
-        self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
+        self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "reuse_existing", False)
         self._max_length = None
@@ -280,10 +288,10 @@ def __init__(  # noqa: C901
         else:  # Free inference client
             self.endpoint = None
             self.endpoint_name = None
-            self.name = config.model
+            self.name = config.model_name
             self.revision = "default"
-            self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
-            self.client = InferenceClient(model=config.model, token=env_config.token)
+            self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
+            self.client = InferenceClient(model=config.model_name, token=env_config.token)

         self.use_async = True  # set to False for debug - async use is faster

@@ -293,7 +301,7 @@ def __init__(  # noqa: C901
         self.model_info = ModelInfo(
             model_name=self.name,
             model_sha=self.revision,
-            model_dtype=config.model_dtype or "default",
+            model_dtype=getattr(config, "model_dtype", "default"),
             model_size=-1,
         )

@@ -545,7 +553,12 @@ def loglikelihood(
                 cont_toks = torch.tensor(cur_request.tokenized_continuation)
                 len_choice = len(cont_toks)

-                logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
+                if self.endpoint:  # inference endpoint
+                    logits = [
+                        t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
+                    ]  # to check
+                else:  # serverless endpoint
+                    logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

                 greedy_tokens = torch.tensor(logits).argmax(dim=-1)
                 max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py
index 30aec21c..dff3b9b4 100644
--- a/src/lighteval/models/model_loader.py
+++ b/src/lighteval/models/model_loader.py
@@ -27,7 +27,7 @@
 from lighteval.models.endpoints.endpoint_model import (
     InferenceEndpointModel,
     InferenceEndpointModelConfig,
-    InferenceModelConfig,
+    ServerlessEndpointModelConfig,
 )
 from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
 from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -84,7 +84,7 @@ def load_model(  # noqa: C901
     if isinstance(config, TGIModelConfig):
         return load_model_with_tgi(config)

-    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
+    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
         return load_model_with_inference_endpoints(config, env_config=env_config)

     if isinstance(config, BaseModelConfig):
diff --git a/tests/models/endpoints/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py
index 29fbb3c4..f4ba15d9 100644
--- a/tests/models/endpoints/test_endpoint_model.py
+++ b/tests/models/endpoints/test_endpoint_model.py
@@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
             },
         ),
         (
-            "examples/model_configs/endpoint_model_lite.yaml",
+            "examples/model_configs/serverless_model.yaml",
             {
                 "model_name": "meta-llama/Llama-3.1-8B-Instruct",
                 # Defaults:
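
For reference, a minimal sketch (not part of the patch) of how the new ServerlessEndpointModelConfig.from_path consumes a config file: it reads the top-level "model" key and passes "base_params" straight into the dataclass, so the YAML only needs model.base_params.model_name. The temp-file path and model name below are illustrative assumptions.

# Sketch only: shows the YAML layout assumed by ServerlessEndpointModelConfig.from_path
# ("model" -> "base_params" -> "model_name"); not an excerpt from the repository.
import tempfile

import yaml

from lighteval.models.endpoints.endpoint_model import ServerlessEndpointModelConfig

example_yaml = {
    "model": {
        "base_params": {
            "model_name": "meta-llama/Llama-3.1-8B-Instruct",  # model served by the serverless API
        }
    }
}

# Write the example config to a temporary file, then load it through from_path.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump(example_yaml, f)
    path = f.name

config = ServerlessEndpointModelConfig.from_path(path)
print(config.model_name)          # "meta-llama/Llama-3.1-8B-Instruct"
print(config.add_special_tokens)  # True (dataclass default)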