Better prompt and message for litellm #453

Closed
wants to merge 30 commits
Changes from all commits
Commits
30 commits
17783b2
Added inference using litellm.
JoelNiklaus Nov 7, 2024
9e92150
Add Udmurt (udm) translation literals (#381)
codemurt Nov 8, 2024
30a624c
This PR adds translation literals for Belarusian language. (#382)
Kryuski Nov 8, 2024
6e6fed6
fix: cache directory variable (#378)
NazimHAli Nov 8, 2024
d1d4c69
greedy_until() fix (#344)
vsabolcec Nov 8, 2024
f69811f
Fixed some params in completion call to enable more model providers.
JoelNiklaus Nov 11, 2024
dabb4a7
Added diskcache.
JoelNiklaus Nov 13, 2024
65f759c
Merge branch 'main' into add_litellm_inference
JoelNiklaus Nov 20, 2024
f74afd4
Merge branch 'main' into add_litellm_inference
JoelNiklaus Nov 22, 2024
88a9838
Fix issue for openai evaluation.
JoelNiklaus Nov 25, 2024
02ed461
Added support for stop sequences and generation size.
JoelNiklaus Nov 26, 2024
34596c2
Merge branch 'main' into add_litellm_inference
JoelNiklaus Nov 26, 2024
190738f
Fixed issue with too many concurrent calls to APIs.
JoelNiklaus Nov 27, 2024
2bb1917
Merge branch 'main' into add_litellm_inference
clefourrier Nov 28, 2024
81e4404
Merge branch 'main' into add_litellm_inference
JoelNiklaus Dec 4, 2024
ebdd900
Merge branch 'main' into add_litellm_inference
NathanHB Dec 5, 2024
251e181
few fixes
NathanHB Dec 6, 2024
47b1888
Fixed issues with stop_sequence, max_completion_tokens and system_pro…
JoelNiklaus Dec 9, 2024
20a1191
Merge branch 'main' into add_litellm_inference
JoelNiklaus Dec 9, 2024
ade8f0c
Revert weird change to __main__.py.
JoelNiklaus Dec 9, 2024
a2587d6
Made configuration simpler.
JoelNiklaus Dec 9, 2024
7c0856e
Merge branch 'main' into add_litellm_inference
JoelNiklaus Dec 12, 2024
932fd2c
Fixed import issues.
JoelNiklaus Dec 12, 2024
8fc9b13
Merge branch 'main' into add_litellm_inference
NathanHB Dec 16, 2024
45d6d1d
fix import location
NathanHB Dec 16, 2024
2a23836
Merge branch 'add_litellm_inference' of github.com:JoelNiklaus/lighte…
NathanHB Dec 16, 2024
cca1446
Merge branch 'main' into add_litellm_inference
JoelNiklaus Dec 16, 2024
1a10351
Enabled passing through system prompt to the models in the requests.
JoelNiklaus Dec 16, 2024
ff6d5de
Fixed some bugs.
JoelNiklaus Dec 17, 2024
78789c1
Allow better message management for litellm
NathanHB Dec 17, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -82,6 +82,7 @@ dependencies = [
]

[project.optional-dependencies]
litellm = ["litellm", "diskcache"]
tgi = ["text-generation==0.6.0"]
optimum = ["optimum==1.12.0"]
quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
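The new litellm extra is what the backend checks for at runtime through is_litellm_available() (see litellm_model.py below); a minimal, purely illustrative guard, assuming the extras are installed as lighteval[litellm]:

from lighteval.utils.imports import is_litellm_available

if not is_litellm_available():
    # Hypothetical error message; the point is simply to fail early when the extra is missing.
    raise ImportError("LiteLLM backend requested; install the optional dependencies with `pip install lighteval[litellm]`.")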
109 changes: 109 additions & 0 deletions src/lighteval/main_endpoint.py
@@ -369,3 +369,112 @@ def tgi(
pipeline.save_and_push_results()

return results


@app.command(rich_help_panel="Evaluation Backends")
def litellm(
# === general ===
model_name: Annotated[
str, Argument(help="The model name to evaluate (has to be available through the litellm API).")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
] = False,
system_prompt: Annotated[
Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
] = None,
dataset_loading_processes: Annotated[
int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
custom_tasks: Annotated[
Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
cache_dir: Annotated[
str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANEL_NAME_1)
] = CACHE_DIR,
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1)
] = 1,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
] = "results",
push_to_hub: Annotated[
bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
push_to_tensorboard: Annotated[
bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
public_run: Annotated[
bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
results_org: Annotated[
Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANEL_NAME_2)
] = None,
save_details: Annotated[
bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANEL_NAME_2)
] = False,
# === debug ===
max_samples: Annotated[
Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3)
] = None,
override_batch_size: Annotated[
int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANEL_NAME_3)
] = -1,
job_id: Annotated[
int, Option(help="Optional job id for future reference.", rich_help_panel=HELP_PANEL_NAME_3)
] = 0,
):
"""
Evaluate models using LiteLLM as the backend.
"""

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.litellm_model import LiteLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
evaluation_tracker = EvaluationTracker(
output_dir=output_dir,
save_details=save_details,
push_to_hub=push_to_hub,
push_to_tensorboard=push_to_tensorboard,
public=public_run,
hub_results_org=results_org,
)

# TODO (nathan): better handling of model_args
parallelism_manager = ParallelismManager.NONE

model_config = LiteLLMModelConfig(model=model_name)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
env_config=env_config,
job_id=job_id,
dataset_loading_processes=dataset_loading_processes,
custom_tasks_directory=custom_tasks,
override_batch_size=override_batch_size,
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
)
pipeline = Pipeline(
tasks=tasks,
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
)

pipeline.evaluate()

pipeline.show_results()

results = pipeline.get_results()

pipeline.save_and_push_results()

return results
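For reviewers, the command body above can also be exercised programmatically; the sketch below mirrors it one-to-one, with the model name and task string as placeholders and any omitted tracker or pipeline arguments assumed to keep their defaults:

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.litellm_model import LiteLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

# Placeholder model and task; any model reachable through LiteLLM should work here.
model_config = LiteLLMModelConfig(model="openai/gpt-4o-mini")
evaluation_tracker = EvaluationTracker(output_dir="results", save_details=False)
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.NONE,
    env_config=EnvConfig(cache_dir="~/.cache/huggingface"),
    use_chat_template=True,
)
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config,
)
pipeline.evaluate()
pipeline.show_results()
pipeline.save_and_push_results()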
1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/openai_model.py
@@ -145,7 +145,6 @@ def greedy_until(

Args:
requests (list[Request]): list of requests containing the context and ending conditions.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.

Returns:
268 changes: 268 additions & 0 deletions src/lighteval/models/litellm_model.py
@@ -0,0 +1,268 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.endpoints.endpoint_model import ModelInfo
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)
from lighteval.utils.imports import is_litellm_available


logger = logging.getLogger(__name__)

if is_litellm_available():
import litellm
from litellm.caching.caching import Cache

logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("LiteLLM").handlers.clear()

litellm.cache = Cache(type="disk")


@dataclass
class LiteLLMModelConfig:
model: str


class LiteLLMClient(LightevalModel):
_DEFAULT_MAX_LENGTH: int = 4096

def __init__(self, config, env_config) -> None:
"""
IMPORTANT: Your API keys should be set in the environment variables.
If a base_url is not set, it will default to the public API.
"""
self.model_info = ModelInfo(
model_name=config.model,
model_sha="",
model_dtype=None,
model_size="",
)
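# Derive the provider from the "<provider>/<model>" prefix; it is used below to look up an optional *_BASE_URL env var and for provider-specific handling.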
self.provider = config.model.split("/")[0]
self.base_url = os.getenv(f"{self.provider.upper()}_BASE_URL", None)
self.API_MAX_RETRY = 5
self.API_RETRY_SLEEP = 3
self.API_RETRY_MULTIPLIER = 2
self.CONCURRENT_CALLS = 20 # 100 leads to hitting Anthropic rate limits
self.TEMPERATURE = 0.7
self.TOP_P = 0.95
self.model = config.model
self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility
self.pairwise_tokenization = False
litellm.drop_params = True
litellm.verbose = True

def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
for attempt in range(self.API_MAX_RETRY):
try:
if self.provider == "anthropic":
# Filter out whitespace-only stop sequences
if stop_sequence:
stop_sequence = [s for s in stop_sequence if s.strip()]
if not stop_sequence: # If empty after filtering
stop_sequence = ["\n"]

# Handle max_new_tokens
completion_tokens = None
if max_new_tokens and max_new_tokens > 0:
completion_tokens = max_new_tokens
if "o1" in self.model:
# We need to allow more tokens to include reasoning tokens
completion_tokens = min(max_new_tokens * 10, 32000)

response = litellm.completion(
model=self.model,
messages=prompt,
max_completion_tokens=completion_tokens,
logprobs=return_logits if self.provider == "openai" else None,
stop=stop_sequence,
base_url=self.base_url,
n=num_samples,
temperature=self.TEMPERATURE,
top_p=self.TOP_P,
caching=True,
)
return response
except Exception as e:
wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt)) # Exponential backoff with max 64s
logger.warning(
f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
)
time.sleep(wait_time)

logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, skipping entry.")

def __call_api_parallel(
self,
prompts,
return_logits: bool | list[bool],
max_new_tokens: int | list[int],
num_samples: int | list[int],
stop_sequence: list[str] | None = None,
system_prompt: str | list[str] | None = None,
):
results = []

# Broadcast any scalar argument to a per-prompt list so executor.map can zip them with the prompts.
return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
stop_sequencess = [stop_sequence for _ in prompts]
system_prompts = [system_prompt for _ in prompts] if not isinstance(system_prompt, list) else system_prompt
assert (
len(prompts)
== len(return_logitss)
== len(max_new_tokenss)
== len(num_sampless)
== len(stop_sequencess)
== len(system_prompts)
), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences, system_prompts should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}, {len(system_prompts)}"

with ThreadPoolExecutor(self.CONCURRENT_CALLS) as executor:
for entry in tqdm(
executor.map(
self.__call_api,
prompts,
return_logitss,
max_new_tokenss,
num_sampless,
stop_sequencess,
system_prompts,
),
total=len(prompts),
):
results.append(entry)

if None in results:
raise ValueError("Some API calls failed after all retries; inspect the logged errors and retry.")

return results

def greedy_until(
self,
requests: list[GreedyUntilRequest],
override_bs: Optional[int] = None,
) -> list[GenerativeResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.

Args:
requests (list[Request]): list of requests containing the context and ending conditions.
override_bs (int, optional): Override the batch size for generation. Defaults to None.

Returns:
list[GenerativeResponse]: list of generated responses.
"""
for request in requests:
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
results = []

for _ in tqdm(
dataset.splits_start_end_iterator(),
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=False, # self.disable_tqdm,
):
contexts = [c.context for c in dataset]
max_new_tokens = dataset[0].generation_size # could be None
return_logits = dataset[0].use_logits
num_samples = dataset[0].num_samples
stop_sequence = requests[0].stop_sequence
system_prompt = requests[0].system_prompt

responses = self.__call_api_parallel(
contexts, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt
)

for response in responses:
result: list[str] = [choice.message.content for choice in response.choices]

cur_response = GenerativeResponse(
result=result,
logits=None,
generated_tokens=[],
input_tokens=[],
)
results.append(cur_response)

return dataset.get_original_order(results)

@property
def tokenizer(self):
return self._tokenizer

def tok_encode(self, text: str):
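# LiteLLM providers tokenize server-side, so contexts are passed through as raw strings.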
return text

@property
def add_special_tokens(self) -> bool:
return False

@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
return 4096

def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError

def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
raise NotImplementedError

def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError
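
As background for reviewers, the heart of __call_api is a cached litellm.completion request; the standalone sketch below reproduces that call with a placeholder model name and assumes the matching provider key (for example OPENAI_API_KEY) is exported:

import litellm
from litellm.caching.caching import Cache

litellm.cache = Cache(type="disk")  # same on-disk cache the client installs at import time
litellm.drop_params = True          # drop parameters a given provider does not support

messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # optional system prompt
    {"role": "user", "content": "Answer in one word: what is the capital of France?"},
]

response = litellm.completion(
    model="openai/gpt-4o-mini",  # placeholder; the provider prefix selects the backend
    messages=messages,
    max_completion_tokens=32,
    stop=["\n"],
    temperature=0.7,
    top_p=0.95,
    n=1,
    caching=True,  # identical repeated calls are served from the disk cache
)
print(response.choices[0].message.content)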