diff --git a/examples/model_configs/endpoint_model.yaml b/examples/model_configs/endpoint_model.yaml
index c3f5222b..3cca5c43 100644
--- a/examples/model_configs/endpoint_model.yaml
+++ b/examples/model_configs/endpoint_model.yaml
@@ -1,7 +1,6 @@
 model:
   base_params:
-    endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
-    model: "meta-llama/Llama-2-7b-hf"
+    model_name: "meta-llama/Llama-2-7b-hf" # the model name, or the endpoint name if reuse_existing is true
     revision: "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
     reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation
diff --git a/examples/model_configs/endpoint_model_lite.yaml b/examples/model_configs/endpoint_model_lite.yaml
new file mode 100644
index 00000000..af1652e1
--- /dev/null
+++ b/examples/model_configs/endpoint_model_lite.yaml
@@ -0,0 +1,3 @@
+model:
+  base_params:
+    model_name: "meta-llama/Llama-3.1-8B-Instruct" # or "Qwen/Qwen2.5-14B", "Qwen/Qwen2.5-7B"
diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py
index 877be2df..5069c414 100644
--- a/src/lighteval/main_endpoint.py
+++ b/src/lighteval/main_endpoint.py
@@ -190,7 +190,7 @@ def inference_endpoint(
     ] = None,
     override_batch_size: Annotated[
         int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3)
-    ] = -1,
+    ] = None,
     job_id: Annotated[
         int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3)
     ] = 0,
@@ -203,7 +203,6 @@ def inference_endpoint(
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.model_config import (
         InferenceEndpointModelConfig,
-        InferenceModelConfig,
     )
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

@@ -219,38 +218,34 @@ def inference_endpoint(

     # TODO (nathan): better handling of model_args
-    parallelism_manager = ParallelismManager.TGI
+    parallelism_manager = ParallelismManager.NONE  # since we're running inference endpoints remotely

     with open(model_config_path, "r") as f:
         config = yaml.safe_load(f)["model"]

-    reuse_existing_endpoint = config["base_params"].get("reuse_existing", None)
-
-    complete_config_endpoint = all(
-        val not in [None, ""]
-        for key, val in config.get("instance", {}).items()
-        if key not in InferenceEndpointModelConfig.nullable_keys()
+    # Find a way to add this back
+    # if config["base_params"].get("endpoint_name", None):
+    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+    all_params = {
+        "model_name": config["base_params"].get("model_name", None),
+        "endpoint_name": config["base_params"].get("endpoint_name", None),
+        "model_dtype": config["base_params"].get("dtype", None),
+        "revision": config["base_params"].get("revision", None) or "main",
+        "should_reuse_existing": config["base_params"].get("should_reuse_existing"),
+        "accelerator": config.get("instance", {}).get("accelerator", None),
+        "region": config.get("instance", {}).get("region", None),
+        "vendor": config.get("instance", {}).get("vendor", None),
+        "instance_size": config.get("instance", {}).get("instance_size", None),
+        "instance_type": config.get("instance", {}).get("instance_type", None),
+        "namespace": config.get("instance", {}).get("namespace", None),
+        "image_url": config.get("instance", {}).get("image_url", None),
+        "env_vars": config.get("instance", {}).get("env_vars", None),
+    }
+    model_config = InferenceEndpointModelConfig(
+        # We only initialize params which have a non-default value
+        **{k: v for k, v in all_params.items() if v is not None},
     )
-    if reuse_existing_endpoint or complete_config_endpoint:
-        model_config = InferenceEndpointModelConfig(
-            name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
-            repository=config["base_params"]["model"],
-            model_dtype=config["base_params"]["dtype"],
-            revision=config["base_params"]["revision"] or "main",
-            should_reuse_existing=reuse_existing_endpoint,
-            accelerator=config["instance"]["accelerator"],
-            region=config["instance"]["region"],
-            vendor=config["instance"]["vendor"],
-            instance_size=config["instance"]["instance_size"],
-            instance_type=config["instance"]["instance_type"],
-            namespace=config["instance"]["namespace"],
-            image_url=config["instance"].get("image_url", None),
-            env_vars=config["instance"].get("env_vars", None),
-        )
-    else:
-        model_config = InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
         env_config=env_config,
diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index de3c2ba1..bc2c7eac 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -21,19 +21,25 @@
 # SOFTWARE.

 import asyncio
+import re
+import time
 from typing import Coroutine, List, Optional, Union

+import requests
 import torch
 from huggingface_hub import (
     AsyncInferenceClient,
     InferenceClient,
     InferenceEndpoint,
+    InferenceEndpointError,
     InferenceEndpointTimeoutError,
     TextGenerationInputGrammarType,
     TextGenerationOutput,
     create_inference_endpoint,
     get_inference_endpoint,
 )
+from huggingface_hub.utils import HfHubHTTPError
+from requests import ConnectionError
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -53,6 +59,17 @@

 BATCH_SIZE = 50
+MAX_TIME_FOR_SPINUP = 3600
+
+SORTED_INSTANCE_SIZES = [  # sorted by incremental overall RAM (to load models)
+    # type, size
+    ("nvidia-a10g", "x1"),
+    ("nvidia-t4", "x4"),
+    ("nvidia-a100", "x1"),
+    ("nvidia-a10g", "x4"),
+    ("nvidia-a100", "x2"),
+    ("nvidia-a100", "x4"),
+]


 class InferenceEndpointModel(LightevalModel):
@@ -60,60 +77,137 @@ class InferenceEndpointModel(LightevalModel):
     endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
     """

-    def __init__(
+    def __init__(  # noqa: C901
        self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "should_reuse_existing", True)
         self._max_length = None
+        self.endpoint = None
+        self.model_name = None
         if isinstance(config, InferenceEndpointModelConfig):
-            if config.should_reuse_existing:
-                self.endpoint = get_inference_endpoint(
-                    name=config.name, token=env_config.token, namespace=config.namespace
+            if config.instance_type and config.instance_size and config.vendor and config.region:
+                vendor, region, instance_type, instance_size = (
+                    config.vendor,
+                    config.region,
+                    config.instance_type,
+                    config.instance_size,
                 )
             else:
-                self.endpoint: InferenceEndpoint = create_inference_endpoint(
-                    name=config.name,
-                    namespace=config.namespace,
-                    repository=config.repository,
-                    revision=config.revision,
-                    framework=config.framework,
-                    task="text-generation",
-                    accelerator=config.accelerator,
-                    vendor=config.vendor,
-                    region=config.region,
-                    type=config.endpoint_type,
-                    instance_size=config.instance_size,
-                    instance_type=config.instance_type,
-                    token=env_config.token,
-                    custom_image={
-                        "health_route": "/health",
-                        "env": {
-                            # Documentaiton: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
-                            "MAX_BATCH_PREFILL_TOKENS": "2048",
-                            "MAX_INPUT_LENGTH": "2047",
-                            "MAX_TOTAL_TOKENS": "2048",
-                            "MODEL_ID": "/repository",
-                            "HF_MODEL_TRUST_REMOTE_CODE": "true",
-                            **config.get_dtype_args(),
-                            **config.get_custom_env_vars(),
-                        },
-                        "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"),
-                    },
-                )
-            hlog("Deploying your endpoint. Please wait.")
-            try:
-                self.endpoint.wait(timeout=600)  # Waits for the endpoint to be deployed
-            except InferenceEndpointTimeoutError as e:
-                hlog_err("Endpoint did not start within 10 minutes, there was a timeout.")
-                raise e
+                try:
+                    vendor, region, instance_type, instance_size = InferenceEndpointModel.get_suggested_model_config(
+                        config.model_name
+                    )
+                except Exception:
+                    vendor, region, instance_type, instance_size = (
+                        "aws",
+                        "us-east-1",
+                        *InferenceEndpointModel.get_larger_hardware_suggestion(),
+                    )
+
+            must_scaleup_endpoint = False
+            timer_start = time.time()
+            # Endpoint names do not allow special characters
+            endpoint_name = config.endpoint_name or re.sub(
+                "[^a-zA-Z0-9-]", "-", config.model_name.lower() + "-lighteval"
+            )
+            # If no endpoint or endpoint not running, and we're below an hour
+            while (self.endpoint is None or self.endpoint.status != "running") and (
+                time.time() - timer_start < MAX_TIME_FOR_SPINUP
+            ):
+                try:
+                    if self.endpoint is None:  # Endpoint does not exist yet locally
+                        if not config.should_reuse_existing:  # New endpoint
+                            hlog("Creating endpoint.")
+                            self.endpoint: InferenceEndpoint = create_inference_endpoint(
+                                name=endpoint_name,
+                                namespace=config.namespace,
+                                repository=config.model_name,
+                                revision=config.revision,
+                                framework=config.framework,
+                                task="text-generation",
+                                accelerator=config.accelerator,
+                                type=config.endpoint_type,
+                                vendor=vendor,
+                                region=region,
+                                instance_size=instance_size,
+                                instance_type=instance_type,
+                                token=env_config.token,
+                                custom_image={
+                                    "health_route": "/health",
+                                    "env": {
+                                        # Documentation: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
+                                        "MAX_BATCH_PREFILL_TOKENS": "2048",
+                                        "MAX_INPUT_LENGTH": "2047",
+                                        "MAX_TOTAL_TOKENS": "2048",
+                                        "MODEL_ID": "/repository",
+                                        "HF_MODEL_TRUST_REMOTE_CODE": "true",
+                                        **config.get_dtype_args(),
+                                        **config.get_custom_env_vars(),
+                                    },
+                                    "url": (
+                                        config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
+                                    ),
+                                },
+                            )
+                        else:  # Endpoint exists
+                            hlog("Reusing existing endpoint.")
+                            self.endpoint = get_inference_endpoint(
+                                name=endpoint_name, token=env_config.token, namespace=config.namespace
+                            )
+
+                    else:
+                        # Endpoint exists locally but either failed (and most likely it must be scaled up)
+                        if must_scaleup_endpoint:
+                            hlog("Rescaling existing endpoint.")
+                            self.endpoint.update(instance_size=instance_size, instance_type=instance_type)
+                            must_scaleup_endpoint = False
+                        # or we got a connection error, in which case we do nothing and just wait at the next step
+
+                    # Waits for the endpoint to be deployed - we could also check for the status in 'updating', 'pending', 'initializing'
+                    hlog("Trying to deploy your endpoint. Please wait for 10 min.")
+                    self.endpoint.wait(timeout=600, refresh_every=60)  # We wait for 10 min
+                except InferenceEndpointError as e:
+                    instance_type, instance_size = InferenceEndpointModel.get_larger_hardware_suggestion(
+                        instance_type, instance_size
+                    )
+                    must_scaleup_endpoint = True
+
+                    hlog(
+                        f"Endpoint failed to start on current hardware with error {e}. Trying to autoscale to ({instance_type}, {instance_size})."
+                    )
+                except InferenceEndpointTimeoutError as e:
+                    hlog_err("Endpoint did not start within 10 minutes, there was a timeout. Please inspect the logs.")
+                    raise e
+                except HfHubHTTPError as e:
+                    # The endpoint actually already exists, we'll spin it up instead of trying to create a new one
+                    if "409 Client Error: Conflict for url:" in str(e):
+                        config.endpoint_name = endpoint_name
+                        config.should_reuse_existing = True
+                    # Requested resources are not available
+                    elif "Bad Request: Compute instance not available yet" in str(e):
+                        hlog_err(
+                            f"The hardware combination you are requesting does not seem to be available: ({instance_type}, {instance_size}, {config.region})."
+                        )
+                        raise e
+                    # User account does not have access to requested resources
+                    elif "Conflict: Quota exceeded" in str(e):
+                        raise e
+                except ConnectionError as e:
+                    hlog_err(f"Connection failed with error {e}. Retrying.")
+
+            if self.endpoint.status != "running":
+                raise Exception("Did not manage to start endpoint within the elapsed time and on suggested hardware.")
+
             hlog("Endpoint successfully deployed!")
-            self.name = config.repository
+            self.endpoint_name = config.endpoint_name
+            self.name = self.endpoint.repository
             self.revision = self.endpoint.revision
             self.async_client: AsyncInferenceClient = self.endpoint.async_client
             self.client: InferenceClient = self.endpoint.client

         else:  # Free inference client
             self.endpoint = None
+            self.endpoint_name = None
             self.name = config.model
             self.revision = "default"
             self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
@@ -131,6 +225,43 @@ def __init__(
             model_size=-1,
         )

+    @staticmethod
+    def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
+        cur_instance_ix = -1
+        try:
+            if cur_instance_type and cur_instance_size:
+                cur_instance_ix = SORTED_INSTANCE_SIZES.index((cur_instance_type, cur_instance_size))
+            new_instance_type = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][0]
+            new_instance_size = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][1]
+            return new_instance_type, new_instance_size
+        except ValueError:
+            raise Exception(
+                f"Problem when scaling endpoint: the current instance combination ({cur_instance_type}, {cur_instance_size}) is unknown. Can't scale it up."
+            )
+        except IndexError:
+            raise Exception(
+                "To avoid accidental costs, we do not upgrade the current endpoint above 4x nvidia-a100 automatically; please request it explicitly."
+            )
+
+    @staticmethod
+    def get_suggested_model_config(model_repo):
+        # Code from https://huggingface.co/spaces/huggingface/dedicated-endpoint-snooper/blob/main/app.py
+        # Example of the suggestedCompute value: 'aws-us-east-1-nvidia-l4-x1'
+        # -> aws us-east-1 nvidia-l4 x1
+        url = f"https://ui.endpoints.huggingface.co/api/configuration?model_id={model_repo}"
+        response = requests.get(url)
+        config = response.json()
+
+        suggested_compute = config["suggestedCompute"]
+        suggested_vendor = suggested_compute.split("-")[0]
+        if suggested_vendor == "azure":
+            suggested_region = suggested_compute.split("-")[1]
+        else:
+            suggested_region = "-".join(suggested_compute.split("-")[1:4])
+        suggested_instance = "-".join(suggested_compute.split("-")[-3:-1])
+        suggested_size = suggested_compute.split("-")[-1]
+        return suggested_vendor, suggested_region, suggested_instance, suggested_size
+
     @property
     def tokenizer(self):
         return self._tokenizer
@@ -144,11 +275,17 @@ def disable_tqdm(self) -> bool:
         False  # no accelerator = this is the main process

     def cleanup(self):
-        if self.endpoint is not None and not self.reuse_existing:
-            self.endpoint.delete()
-            hlog_warn(
-                "You deleted your endpoint after using it. You'll need to create it again if you need to reuse it."
-            )
+        if self.endpoint is not None:
+            if self.reuse_existing:
+                self.endpoint.pause()
+                hlog_warn(
+                    "Since your endpoint already existed, we did not delete it but paused it instead. You might want to delete it if you're done using it."
+                )
+            else:
+                self.endpoint.delete()
+                hlog_warn(
+                    "We deleted the spun-up endpoint after using it. You'll need to create it again if you want to reuse it."
+                )

     @property
     def max_length(self):
diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index 2afaead1..268e2a6f 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -267,24 +267,36 @@ class InferenceModelConfig:

 @dataclass
 class InferenceEndpointModelConfig:
-    name: str
-    repository: str
-    accelerator: str
-    vendor: str
-    region: str
-    instance_size: str
-    instance_type: str
-    model_dtype: str
+    endpoint_name: str = None
+    model_name: str = None
+    should_reuse_existing: bool = False
+    accelerator: str = "gpu"
+    model_dtype: str = None  # if empty, we use the default
+    vendor: str = "aws"
+    region: str = "us-east-1"  # this region has the most hardware options available
+    instance_size: str = None  # if none, we autoscale
+    instance_type: str = None  # if none, we autoscale
     framework: str = "pytorch"
     endpoint_type: str = "protected"
-    should_reuse_existing: bool = False
     add_special_tokens: bool = True
     revision: str = "main"
     namespace: str = None  # The namespace under which to launch the endopint. Defaults to the current user's namespace
     image_url: str = None
     env_vars: dict = None

+    def __post_init__(self):
+        # xor operator, one is None but not the other
+        if (self.instance_size is None) ^ (self.instance_type is None):
+            raise ValueError(
+                "When creating an inference endpoint, you need to explicitly specify both instance_type and instance_size, or none of them for autoscaling."
+            )
+
+        if not ((self.endpoint_name is None) ^ (self.model_name is None)):
+            raise ValueError("You need to set either endpoint_name or model_name (but not both).")
+
     def get_dtype_args(self) -> Dict[str, str]:
+        if self.model_dtype is None:
+            return {}
         model_dtype = self.model_dtype.lower()
         if model_dtype in ["awq", "eetq", "gptq"]:
             return {"QUANTIZE": model_dtype}
@@ -298,12 +310,3 @@ def get_dtype_args(self) -> Dict[str, str]:

     def get_custom_env_vars(self) -> Dict[str, str]:
         return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {}
-
-    @staticmethod
-    def nullable_keys() -> list[str]:
-        """
-        Returns the list of optional keys in an endpoint model configuration. By default, the code requires that all the
-        keys be specified in the configuration in order to launch the endpoint. This function returns the list of keys
-        that are not required and can remain None.
-        """
-        return ["namespace", "env_vars", "image_url"]
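Note (not part of the diff): a minimal sketch of how the new lite YAML above is consumed. Only base_params.model_name is required; anything left unset falls back to the InferenceEndpointModelConfig defaults introduced in model_config.py (vendor "aws", region "us-east-1", instance_type/instance_size None, i.e. autoscaled hardware). The filtering mirrors the all_params handling added in main_endpoint.py; the trailing print is purely illustrative.

import yaml

from lighteval.models.model_config import InferenceEndpointModelConfig

with open("examples/model_configs/endpoint_model_lite.yaml", "r") as f:
    config = yaml.safe_load(f)["model"]

base_params = config["base_params"]
instance = config.get("instance", {})
# Only forward keys the user actually set, so the dataclass defaults apply for the rest.
all_params = {
    "model_name": base_params.get("model_name"),
    "endpoint_name": base_params.get("endpoint_name"),
    "model_dtype": base_params.get("dtype"),
    "revision": base_params.get("revision") or "main",
    "instance_type": instance.get("instance_type"),
    "instance_size": instance.get("instance_size"),
}
model_config = InferenceEndpointModelConfig(**{k: v for k, v in all_params.items() if v is not None})
print(model_config.model_name, model_config.vendor, model_config.region, model_config.instance_type)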
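Note (not part of the diff): the new __post_init__ validation in InferenceEndpointModelConfig enforces two rules, sketched below with illustrative values. instance_type and instance_size must be set together (or both left unset for autoscaling), and exactly one of endpoint_name / model_name must be given. Constructing the config does not spin anything up, so these calls are safe to try.

from lighteval.models.model_config import InferenceEndpointModelConfig

# OK: model_name only -> hardware is auto-suggested, endpoint name derived from the model name.
InferenceEndpointModelConfig(model_name="meta-llama/Llama-3.1-8B-Instruct")

# OK: explicit hardware, both instance_type and instance_size given.
InferenceEndpointModelConfig(
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    instance_type="nvidia-a10g",
    instance_size="x1",
)

# ValueError: instance_type set without instance_size.
# InferenceEndpointModelConfig(model_name="meta-llama/Llama-3.1-8B-Instruct", instance_type="nvidia-a10g")

# ValueError: neither endpoint_name nor model_name set (same if both are set).
# InferenceEndpointModelConfig()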
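Note (not part of the diff): a self-contained sketch of the hardware-escalation logic behind get_larger_hardware_suggestion. next_hardware is a hypothetical stand-in for the static method, and the ladder is copied from SORTED_INSTANCE_SIZES in endpoint_model.py above.

# Ladder copied from SORTED_INSTANCE_SIZES (sorted by overall RAM available to load models).
SORTED_INSTANCE_SIZES = [
    ("nvidia-a10g", "x1"),
    ("nvidia-t4", "x4"),
    ("nvidia-a100", "x1"),
    ("nvidia-a10g", "x4"),
    ("nvidia-a100", "x2"),
    ("nvidia-a100", "x4"),
]


def next_hardware(cur_type=None, cur_size=None):
    # No current hardware -> start at the smallest entry; otherwise step one rung up the ladder.
    # Stepping past ("nvidia-a100", "x4") raises IndexError, which the real method turns into an
    # explicit "we do not upgrade further automatically" error to avoid accidental costs.
    cur_ix = SORTED_INSTANCE_SIZES.index((cur_type, cur_size)) if cur_type and cur_size else -1
    return SORTED_INSTANCE_SIZES[cur_ix + 1]


print(next_hardware())                     # ('nvidia-a10g', 'x1')
print(next_hardware("nvidia-a100", "x1"))  # ('nvidia-a10g', 'x4')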