Adds serverless endpoints back (#445)
* init

* adding serverless endpoints back

* updated tests
clefourrier authored Dec 17, 2024
1 parent bc62cb9 commit 51ca581
Showing 6 changed files with 37 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/source/package_reference/models.mdx
@@ -21,7 +21,7 @@
## Endpoints-based Models
### InferenceEndpointModel
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

### TGI ModelClient
19 changes: 12 additions & 7 deletions src/lighteval/main_endpoint.py
@@ -146,6 +146,13 @@ def inference_endpoint(
str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
free_endpoint: Annotated[
bool,
Option(
help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
rich_help_panel=HELP_PANEL_NAME_4,
),
] = False,
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
@@ -200,9 +207,7 @@ def inference_endpoint(
"""

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModelConfig,
)
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -220,10 +225,10 @@
parallelism_manager = ParallelismManager.NONE # since we're using inference endpoints in remote

# Find a way to add this back
# if config["base_params"].get("endpoint_name", None):
# return InferenceModelConfig(model=config["base_params"]["endpoint_name"])

model_config = InferenceEndpointModelConfig.from_path(model_config_path)
if free_endpoint:
model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
else:
model_config = InferenceEndpointModelConfig.from_path(model_config_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
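Note: the `free_endpoint` parameter added above is a standard Typer boolean option, so it surfaces on the command line as a `--free-endpoint` flag. A minimal, self-contained sketch of that pattern, using illustrative names rather than lighteval's real CLI wiring:

    from typing import Annotated

    import typer

    app = typer.Typer()


    @app.command()
    def inference_endpoint(
        model_config_path: Annotated[str, typer.Argument(help="Path to a model config yaml file.")],
        free_endpoint: Annotated[
            bool,
            typer.Option(help="Use serverless free endpoints instead of a dedicated inference endpoint."),
        ] = False,
    ):
        # Stand-in body: report which branch the real command would take.
        kind = "serverless" if free_endpoint else "dedicated"
        typer.echo(f"Would evaluate {model_config_path} on a {kind} endpoint.")


    if __name__ == "__main__":
        app()

Run as, e.g., `python sketch.py my_config.yaml --free-endpoint`; omitting the flag keeps the dedicated-endpoint behaviour, matching the `= False` default above.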
29 changes: 21 additions & 8 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -75,10 +75,18 @@


@dataclass
class InferenceModelConfig:
model: str
class ServerlessEndpointModelConfig:
model_name: str
add_special_tokens: bool = True

@classmethod
def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
import yaml

with open(path, "r") as f:
config = yaml.safe_load(f)["model"]
return cls(**config["base_params"])


@dataclass
class InferenceEndpointModelConfig:
@@ -150,7 +158,7 @@ class InferenceEndpointModel(LightevalModel):
"""

def __init__( # noqa: C901
self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
) -> None:
self.reuse_existing = getattr(config, "reuse_existing", False)
self._max_length = None
@@ -280,10 +288,10 @@ def __init__( # noqa: C901
else: # Free inference client
self.endpoint = None
self.endpoint_name = None
self.name = config.model
self.name = config.model_name
self.revision = "default"
self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
self.client = InferenceClient(model=config.model, token=env_config.token)
self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
self.client = InferenceClient(model=config.model_name, token=env_config.token)

self.use_async = True # set to False for debug - async use is faster

@@ -293,7 +301,7 @@ def __init__( # noqa: C901
self.model_info = ModelInfo(
model_name=self.name,
model_sha=self.revision,
model_dtype=config.model_dtype or "default",
model_dtype=getattr(config, "model_dtype", "default"),
model_size=-1,
)

@@ -545,7 +553,12 @@ def loglikelihood(
cont_toks = torch.tensor(cur_request.tokenized_continuation)
len_choice = len(cont_toks)

logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
if self.endpoint: # inference endpoint
logits = [
t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
] # to check
else: # serverless endpoint
logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

greedy_tokens = torch.tensor(logits).argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
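Usage note for the new config class above: `ServerlessEndpointModelConfig.from_path` loads the `model.base_params` section of a YAML file and feeds it to the dataclass constructor. A small sketch; the YAML layout in the comment is inferred from the classmethod, since the example file's actual contents are not part of this diff:

    # Inferred layout of a serverless config file such as
    # examples/model_configs/serverless_model.yaml (contents assumed, not shown here):
    #
    #   model:
    #     base_params:
    #       model_name: "meta-llama/Llama-3.1-8B-Instruct"
    #       add_special_tokens: true
    #
    from lighteval.models.endpoints.endpoint_model import ServerlessEndpointModelConfig

    config = ServerlessEndpointModelConfig.from_path("examples/model_configs/serverless_model.yaml")
    print(config.model_name, config.add_special_tokens)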
4 changes: 2 additions & 2 deletions src/lighteval/models/model_loader.py
@@ -27,7 +27,7 @@
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModel,
InferenceEndpointModelConfig,
InferenceModelConfig,
ServerlessEndpointModelConfig,
)
from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -80,7 +80,7 @@ def load_model( # noqa: C901
if isinstance(config, TGIModelConfig):
return load_model_with_tgi(config)

if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
return load_model_with_inference_endpoints(config, env_config=env_config)

if isinstance(config, BaseModelConfig):
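With the isinstance check updated, both endpoint config classes take the inference-endpoints branch of `load_model`. A rough sketch of that call path, assuming the `load_model(config, env_config)` signature implied by this diff and placeholder credentials:

    from lighteval.models.endpoints.endpoint_model import ServerlessEndpointModelConfig
    from lighteval.models.model_loader import load_model
    from lighteval.pipeline import EnvConfig

    env_config = EnvConfig(token="hf_xxx", cache_dir="tmp")  # placeholder token and cache dir
    config = ServerlessEndpointModelConfig(model_name="meta-llama/Llama-3.1-8B-Instruct")

    # Both ServerlessEndpointModelConfig and InferenceEndpointModelConfig are routed
    # to load_model_with_inference_endpoints by the check above.
    model = load_model(config=config, env_config=env_config)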
2 changes: 1 addition & 1 deletion tests/models/endpoints/test_endpoint_model.py
@@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
},
),
(
"examples/model_configs/endpoint_model_lite.yaml",
"examples/model_configs/serverless_model.yaml",
{
"model_name": "meta-llama/Llama-3.1-8B-Instruct",
# Defaults:
