Autoscaling inference endpoints (#412)
* added better management for restarts and resizes
* upgraded autoscaling
* added a pause option
* fixed the parallelism manager - no need for an endpoint
clefourrier authored Dec 5, 2024
1 parent 3929825 commit b68d5bc
Showing 5 changed files with 230 additions and 93 deletions.
3 changes: 1 addition & 2 deletions examples/model_configs/endpoint_model.yaml
@@ -1,7 +1,6 @@
 model:
   base_params:
-    endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
-    model: "meta-llama/Llama-2-7b-hf"
+    model_name: "meta-llama/Llama-2-7b-hf" # the model name or the endpoint name if reuse_existing is true
     revision: "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit" or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
     reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation
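
The `endpoint_name` casing caveat could be dropped because names are now sanitized automatically. A minimal sketch of that sanitization, mirroring the `re.sub` call added in `endpoint_model.py` below (the model name is just an example):

```python
import re

# Endpoint names may only contain alphanumerics and dashes, so everything else
# is replaced with "-", and a "-lighteval" suffix is appended when no
# endpoint_name is given in the config.
model_name = "meta-llama/Llama-2-7b-hf"
endpoint_name = re.sub("[^a-zA-Z0-9-]", "-", model_name.lower() + "-lighteval")
print(endpoint_name)  # meta-llama-llama-2-7b-hf-lighteval
```
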
3 changes: 3 additions & 0 deletions examples/model_configs/endpoint_model_lite.yaml
@@ -0,0 +1,3 @@
+model:
+  base_params:
+    model_name: "meta-llama/Llama-3.1-8B-Instruct" # or "Qwen/Qwen2.5-14B", "Qwen/Qwen2.5-7B"
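
With only `model_name` set, hardware selection is deferred to runtime: the endpoints suggestion API is queried first, and on failure the instance ladder defined in `endpoint_model.py` below is walked. A standalone sketch of that ladder walk, where `next_hardware` is a hypothetical condensation of `get_larger_hardware_suggestion`:

```python
SORTED_INSTANCE_SIZES = [  # copied from endpoint_model.py below
    ("nvidia-a10g", "x1"),
    ("nvidia-t4", "x4"),
    ("nvidia-a100", "x1"),
    ("nvidia-a10g", "x4"),
    ("nvidia-a100", "x2"),
    ("nvidia-a100", "x4"),
]

def next_hardware(cur_type=None, cur_size=None):
    # No current hardware -> start at the smallest rung; otherwise step up one rung.
    ix = SORTED_INSTANCE_SIZES.index((cur_type, cur_size)) if cur_type and cur_size else -1
    return SORTED_INSTANCE_SIZES[ix + 1]

print(next_hardware())                     # ('nvidia-a10g', 'x1')
print(next_hardware("nvidia-a10g", "x1"))  # ('nvidia-t4', 'x4')
```
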
51 changes: 23 additions & 28 deletions src/lighteval/main_endpoint.py
@@ -190,7 +190,7 @@ def inference_endpoint(
     ] = None,
     override_batch_size: Annotated[
         int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3)
-    ] = -1,
+    ] = None,
     job_id: Annotated[
         int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3)
     ] = 0,
@@ -203,7 +203,6 @@ def inference_endpoint(
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.model_config import (
         InferenceEndpointModelConfig,
-        InferenceModelConfig,
     )
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

@@ -219,38 +218,34 @@ def inference_endpoint(

     # TODO (nathan): better handling of model_args
 
-    parallelism_manager = ParallelismManager.TGI
+    parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote
 
     with open(model_config_path, "r") as f:
         config = yaml.safe_load(f)["model"]
 
-    reuse_existing_endpoint = config["base_params"].get("reuse_existing", None)
-
-    complete_config_endpoint = all(
-        val not in [None, ""]
-        for key, val in config.get("instance", {}).items()
-        if key not in InferenceEndpointModelConfig.nullable_keys()
-    )
-
-    if reuse_existing_endpoint or complete_config_endpoint:
-        model_config = InferenceEndpointModelConfig(
-            name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
-            repository=config["base_params"]["model"],
-            model_dtype=config["base_params"]["dtype"],
-            revision=config["base_params"]["revision"] or "main",
-            should_reuse_existing=reuse_existing_endpoint,
-            accelerator=config["instance"]["accelerator"],
-            region=config["instance"]["region"],
-            vendor=config["instance"]["vendor"],
-            instance_size=config["instance"]["instance_size"],
-            instance_type=config["instance"]["instance_type"],
-            namespace=config["instance"]["namespace"],
-            image_url=config["instance"].get("image_url", None),
-            env_vars=config["instance"].get("env_vars", None),
-        )
-    else:
-        model_config = InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+    # Find a way to add this back
+    # if config["base_params"].get("endpoint_name", None):
+    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+    all_params = {
+        "model_name": config["base_params"].get("model_name", None),
+        "endpoint_name": config["base_params"].get("endpoint_name", None),
+        "model_dtype": config["base_params"].get("dtype", None),
+        "revision": config["base_params"].get("revision", None) or "main",
+        "should_reuse_existing": config["base_params"].get("should_reuse_existing"),
+        "accelerator": config.get("instance", {}).get("accelerator", None),
+        "region": config.get("instance", {}).get("region", None),
+        "vendor": config.get("instance", {}).get("vendor", None),
+        "instance_size": config.get("instance", {}).get("instance_size", None),
+        "instance_type": config.get("instance", {}).get("instance_type", None),
+        "namespace": config.get("instance", {}).get("namespace", None),
+        "image_url": config.get("instance", {}).get("image_url", None),
+        "env_vars": config.get("instance", {}).get("env_vars", None),
+    }
+    model_config = InferenceEndpointModelConfig(
+        # We only initialize params which have a non default value
+        **{k: v for k, v in all_params.items() if v is not None},
+    )
 
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
         env_config=env_config,
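
The `**{k: v for k, v in all_params.items() if v is not None}` construction above only forwards keys the user actually set, so the config class's own defaults survive. A minimal standalone sketch of the idiom, where `EndpointConfig` is a hypothetical stand-in for `InferenceEndpointModelConfig`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EndpointConfig:  # hypothetical stand-in, not lighteval's actual class
    model_name: Optional[str] = None
    revision: str = "main"
    instance_type: Optional[str] = None

all_params = {
    "model_name": "meta-llama/Llama-3.1-8B-Instruct",
    "revision": None,       # not set in the YAML
    "instance_type": None,  # not set in the YAML
}
# Dropping None values lets the dataclass defaults (e.g. revision="main") apply.
config = EndpointConfig(**{k: v for k, v in all_params.items() if v is not None})
print(config)
# EndpointConfig(model_name='meta-llama/Llama-3.1-8B-Instruct', revision='main', instance_type=None)
```
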
227 changes: 182 additions & 45 deletions src/lighteval/models/endpoint_model.py
@@ -21,19 +21,25 @@
 # SOFTWARE.
 
 import asyncio
+import re
+import time
 from typing import Coroutine, List, Optional, Union
 
+import requests
 import torch
 from huggingface_hub import (
     AsyncInferenceClient,
     InferenceClient,
     InferenceEndpoint,
+    InferenceEndpointError,
     InferenceEndpointTimeoutError,
     TextGenerationInputGrammarType,
     TextGenerationOutput,
     create_inference_endpoint,
     get_inference_endpoint,
 )
+from huggingface_hub.utils import HfHubHTTPError
+from requests import ConnectionError
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -53,67 +59,155 @@


 BATCH_SIZE = 50
+MAX_TIME_FOR_SPINUP = 3600
+
+SORTED_INSTANCE_SIZES = [  # sorted by incremental overall RAM (to load models)
+    # type, size
+    ("nvidia-a10g", "x1"),
+    ("nvidia-t4", "x4"),
+    ("nvidia-a100", "x1"),
+    ("nvidia-a10g", "x4"),
+    ("nvidia-a100", "x2"),
+    ("nvidia-a100", "x4"),
+]


 class InferenceEndpointModel(LightevalModel):
     """InferenceEndpointModels can be used both with the free inference client, or with inference
     endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
     """

-    def __init__(
+    def __init__(  # noqa: C901
         self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "should_reuse_existing", True)
         self._max_length = None
+        self.endpoint = None
+        self.model_name = None
         if isinstance(config, InferenceEndpointModelConfig):
-            if config.should_reuse_existing:
-                self.endpoint = get_inference_endpoint(
-                    name=config.name, token=env_config.token, namespace=config.namespace
-                )
-            else:
-                self.endpoint: InferenceEndpoint = create_inference_endpoint(
-                    name=config.name,
-                    namespace=config.namespace,
-                    repository=config.repository,
-                    revision=config.revision,
-                    framework=config.framework,
-                    task="text-generation",
-                    accelerator=config.accelerator,
-                    vendor=config.vendor,
-                    region=config.region,
-                    type=config.endpoint_type,
-                    instance_size=config.instance_size,
-                    instance_type=config.instance_type,
-                    token=env_config.token,
-                    custom_image={
-                        "health_route": "/health",
-                        "env": {
-                            # Documentaiton: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
-                            "MAX_BATCH_PREFILL_TOKENS": "2048",
-                            "MAX_INPUT_LENGTH": "2047",
-                            "MAX_TOTAL_TOKENS": "2048",
-                            "MODEL_ID": "/repository",
-                            "HF_MODEL_TRUST_REMOTE_CODE": "true",
-                            **config.get_dtype_args(),
-                            **config.get_custom_env_vars(),
-                        },
-                        "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"),
-                    },
-                )
-            hlog("Deploying your endpoint. Please wait.")
-            try:
-                self.endpoint.wait(timeout=600)  # Waits for the endpoint to be deployed
-            except InferenceEndpointTimeoutError as e:
-                hlog_err("Endpoint did not start within 10 minutes, there was a timeout.")
-                raise e
-            self.name = config.repository
+            if config.instance_type and config.instance_size and config.vendor and config.region:
+                vendor, region, instance_type, instance_size = (
+                    config.vendor,
+                    config.region,
+                    config.instance_type,
+                    config.instance_size,
+                )
+            else:
+                try:
+                    vendor, region, instance_type, instance_size = InferenceEndpointModel.get_suggested_model_config(
+                        config.model_name
+                    )
+                except Exception:
+                    vendor, region, instance_type, instance_size = (
+                        "aws",
+                        "us-east-1",
+                        *InferenceEndpointModel.get_larger_hardware_suggestion(),
+                    )
+
+            must_scaleup_endpoint = False
+            timer_start = time.time()
+            # Endpoint names do not allow special characters
+            endpoint_name = config.endpoint_name or re.sub(
+                "[^a-zA-Z0-9-]", "-", config.model_name.lower() + "-lighteval"
+            )
+            # If no endpoint or endpoint not running, and we're below an hour
+            while (self.endpoint is None or self.endpoint.status != "running") and (
+                time.time() - timer_start < MAX_TIME_FOR_SPINUP
+            ):
+                try:
+                    if self.endpoint is None:  # Endpoint does not exist yet locally
+                        if not config.should_reuse_existing:  # New endpoint
+                            hlog("Creating endpoint.")
+                            self.endpoint: InferenceEndpoint = create_inference_endpoint(
+                                name=endpoint_name,
+                                namespace=config.namespace,
+                                repository=config.model_name,
+                                revision=config.revision,
+                                framework=config.framework,
+                                task="text-generation",
+                                accelerator=config.accelerator,
+                                type=config.endpoint_type,
+                                vendor=vendor,
+                                region=region,
+                                instance_size=instance_size,
+                                instance_type=instance_type,
+                                token=env_config.token,
+                                custom_image={
+                                    "health_route": "/health",
+                                    "env": {
+                                        # Documentation: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
+                                        "MAX_BATCH_PREFILL_TOKENS": "2048",
+                                        "MAX_INPUT_LENGTH": "2047",
+                                        "MAX_TOTAL_TOKENS": "2048",
+                                        "MODEL_ID": "/repository",
+                                        "HF_MODEL_TRUST_REMOTE_CODE": "true",
+                                        **config.get_dtype_args(),
+                                        **config.get_custom_env_vars(),
+                                    },
+                                    "url": (
+                                        config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
+                                    ),
+                                },
+                            )
+                        else:  # Endpoint exists
+                            hlog("Reusing existing endpoint.")
+                            self.endpoint = get_inference_endpoint(
+                                name=endpoint_name, token=env_config.token, namespace=config.namespace
+                            )
+
+                    else:
+                        # Endpoint exists locally but either failed (and most likely it must be scaled up)
+                        if must_scaleup_endpoint:
+                            hlog("Rescaling existing endpoint.")
+                            self.endpoint.update(instance_size=instance_size, instance_type=instance_type)
+                            must_scaleup_endpoint = False
+                        # or we got a connection error, in which case we do nothing and just wait at the next step
+
+                    # Waits for the endpoint to be deployed - we could also check for the status in ['updating', 'pending', 'initializing']
+                    hlog("Trying to deploy your endpoint. Please wait for 10 min.")
+                    self.endpoint.wait(timeout=600, refresh_every=60)  # We wait for 10 min
+                except InferenceEndpointError as e:
+                    instance_type, instance_size = InferenceEndpointModel.get_larger_hardware_suggestion(
+                        instance_type, instance_size
+                    )
+                    must_scaleup_endpoint = True
+
+                    hlog(
+                        f"Endpoint failed to start on current hardware with error {e}. Trying to autoscale to ({instance_type}, {instance_size})."
+                    )
+                except InferenceEndpointTimeoutError as e:
+                    hlog_err("Endpoint did not start within 30 minutes, there was a timeout. Please inspect the logs.")
+                    raise e
+                except HfHubHTTPError as e:
+                    # The endpoint actually already exists, we'll spin it up instead of trying to create a new one
+                    if "409 Client Error: Conflict for url:" in str(e):
+                        config.endpoint_name = endpoint_name
+                        config.should_reuse_existing = True
+                    # Requested resources are not available
+                    elif "Bad Request: Compute instance not available yet" in str(e):
+                        hlog_err(
+                            f"The hardware combination you are requesting does not seem to be available: ({instance_type}, {instance_size}, {config.region})."
+                        )
+                        raise e
+                    # User account does not have access to requested resources
+                    elif "Conflict: Quota exceeded" in str(e):
+                        raise e
+                except ConnectionError as e:
+                    hlog_err(f"Connection failed with error {e}. Retrying")
+
+            if not self.endpoint.status == "running":
+                raise Exception("Did not manage to start endpoint within the elapsed time and on suggested hardware.")
+
+            hlog("Endpoint successfully deployed!")
+            self.endpoint_name = config.endpoint_name
+            self.name = self.endpoint.repository
+            self.revision = self.endpoint.revision
             self.async_client: AsyncInferenceClient = self.endpoint.async_client
             self.client: InferenceClient = self.endpoint.client

         else:  # Free inference client
             self.endpoint = None
+            self.endpoint_name = None
             self.name = config.model
             self.revision = "default"
             self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
@@ -131,6 +225,43 @@ def __init__(
             model_size=-1,
         )

+    @staticmethod
+    def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
+        cur_instance_ix = -1
+        try:
+            if cur_instance_type and cur_instance_size:
+                cur_instance_ix = SORTED_INSTANCE_SIZES.index((cur_instance_type, cur_instance_size))
+            new_instance_type = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][0]
+            new_instance_size = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][1]
+            return new_instance_type, new_instance_size
+        except ValueError:
+            raise Exception(
+                f"Problem when scaling endpoint: the current instance combination ({cur_instance_type}, {cur_instance_size}) is unknown. Can't scale it up."
+            )
+        except IndexError:
+            raise Exception(
+                "To avoid accidental costs, we do not upgrade the current endpoint above 4 a100 automatically, please request it explicitly."
+            )
+
+    @staticmethod
+    def get_suggested_model_config(model_repo):
+        # Code from https://huggingface.co/spaces/huggingface/dedicated-endpoint-snooper/blob/main/app.py
+        # Example of the suggestedCompute value: 'aws-us-east-1-nvidia-l4-x1'
+        # -> aws us-east-1 nvidia-l4 x1
+        url = f"https://ui.endpoints.huggingface.co/api/configuration?model_id={model_repo}"
+        response = requests.get(url)
+        config = response.json()
+
+        suggested_compute = config["suggestedCompute"]
+        suggested_vendor = suggested_compute.split("-")[0]
+        if suggested_vendor == "azure":
+            suggested_region = suggested_compute.split("-")[1]
+        else:
+            suggested_region = "-".join(suggested_compute.split("-")[1:4])
+        suggested_instance = "-".join(suggested_compute.split("-")[-3:-1])
+        suggested_size = suggested_compute.split("-")[-1]
+        return suggested_vendor, suggested_region, suggested_instance, suggested_size

     @property
     def tokenizer(self):
         return self._tokenizer
@@ -144,11 +275,17 @@ def disable_tqdm(self) -> bool:
         False  # no accelerator = this is the main process
 
     def cleanup(self):
-        if self.endpoint is not None and not self.reuse_existing:
-            self.endpoint.delete()
-            hlog_warn(
-                "You deleted your endpoint after using it. You'll need to create it again if you need to reuse it."
-            )
+        if self.endpoint is not None:
+            if self.reuse_existing:
+                self.endpoint.pause()
+                hlog_warn(
+                    "Since your endpoint was existing before, we did not delete it, but paused it instead. You might want to delete it if you're done using it."
+                )
+            else:
+                self.endpoint.delete()
+                hlog_warn(
+                    "We deleted the spun up endpoint after using it. You'll need to create it again if you need to reuse it."
+                )

     @property
     def max_length(self):
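
For reference, a worked example of the `suggestedCompute` parsing done by `get_suggested_model_config` above, using the sample value quoted in its comments:

```python
suggested_compute = "aws-us-east-1-nvidia-l4-x1"
parts = suggested_compute.split("-")
vendor = parts[0]  # 'aws'
# Azure regions are a single token; AWS/GCP regions span three tokens ('us-east-1').
region = parts[1] if vendor == "azure" else "-".join(parts[1:4])
instance = "-".join(parts[-3:-1])  # 'nvidia-l4'
size = parts[-1]                   # 'x1'
print(vendor, region, instance, size)  # aws us-east-1 nvidia-l4 x1
```
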
