Commit
allow better message management for litellm
NathanHB committed Dec 17, 2024
1 parent ff6d5de commit 78789c1
Showing 2 changed files with 18 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/lighteval/models/litellm_model.py
@@ -112,7 +112,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se

         response = litellm.completion(
             model=self.model,
-            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
+            messages=prompt,
             max_completion_tokens=completion_tokens,
             logprobs=return_logits if self.provider == "openai" else None,
             stop=stop_sequence,
@@ -234,7 +234,7 @@ def tokenizer(self):
         return self._tokenizer
 
     def tok_encode(self, text: str):
-        return self.tokenizer.encode(text)
+        return text
 
     @property
     def add_special_tokens(self) -> bool:
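Note on the hunks above: __call_api no longer builds a system/user pair itself; it forwards the caller's prompt straight to litellm.completion, so the prompt is expected to already be a chat-format message list, and tok_encode becomes a pass-through because the client has no local tokenizer. A minimal sketch of the call shape this implies, with the model name and message contents as invented placeholders, not taken from the commit:

import litellm

# The prompt handed to __call_api is now the full chat history,
# not a bare string wrapped into system/user roles by the client.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

response = litellm.completion(
    model="gpt-4o-mini",        # placeholder model name, not from this commit
    messages=messages,          # passed through unchanged by __call_api
    max_completion_tokens=64,
)
print(response.choices[0].message.content)
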
20 changes: 16 additions & 4 deletions src/lighteval/tasks/prompt_manager.py
@@ -29,6 +29,7 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 from lighteval.models.abstract_model import LightevalModel
+from lighteval.models.litellm_model import LiteLLMClient
 from lighteval.tasks.requests import Doc
 from lighteval.utils.utils import as_list

@@ -205,7 +206,10 @@ def _single_turn_context(
             system_prompt=system_prompt,
             use_chat_template=use_chat_template,
         )
-        toks = self.model.tok_encode(output)
+        if not use_chat_template:
+            toks = self.model.tok_encode(output)
+        else:
+            toks = "".join([msg["content"] for msg in output])
 
         # If we need to truncate few-shots to fit in the context
         if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None:
@@ -223,9 +227,17 @@
                 system_prompt=system_prompt,
                 use_chat_template=use_chat_template,
             )
-            toks = self.model.tokenizer(output)["input_ids"]
+            if not use_chat_template:
+                toks = self.model.tok_encode(output)
+            else:
+                toks = "".join([msg["content"] for msg in output])
 
+        if isinstance(self.model, LiteLLMClient):
+            return output, num_effective_fewshots
+
-        return output, num_effective_fewshots
+        return self.model.tokenizer.apply_chat_template(
+            output, tokenize=False, add_generation_prompt=True
+        ), num_effective_fewshots
 
     def get_examples(
         self,
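
With tok_encode reduced to a pass-through, the truncation logic in the hunk above cannot count real tokens when a chat template is in use; it falls back to the concatenated message contents, so the few-shot budget is presumably measured in characters rather than token ids for LiteLLM-style clients. A rough illustration of that approximation (the example messages are invented):

# Chat-template path: `output` is a list of {"role", "content"} dicts.
output = [
    {"role": "system", "content": "Answer with a single word."},
    {"role": "user", "content": "Q: What is 2 + 2?\nA:"},
]

# No tokenizer involved: the "token" sequence is just the joined contents,
# so any later length comparison operates on characters, not token ids.
toks = "".join([msg["content"] for msg in output])
print(len(toks))  # character count standing in for a token count
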
@@ -256,7 +268,7 @@ def get_examples(
                 examples.insert(0, {"role": "system", "content": system_prompt + instruction})
             else:  # Else we add the instruction to the first example
                 examples[0]["content"] = instruction + examples[0]["content"]
-            return self.model.tokenizer.apply_chat_template(examples, tokenize=False, add_generation_prompt=True)
+            return examples
         else:
             if system_prompt is not None:
                 output = system_prompt + instruction + "\n\n".join(examples)
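Taken together, get_examples above now returns the raw message list, and _single_turn_context decides how to finish it: a LiteLLMClient gets the messages as-is (the provider applies its own chat template server-side), while other backends still receive a flat string rendered through the tokenizer's chat template. A small sketch of that branching, with finalize_prompt as a hypothetical stand-in for the logic added to _single_turn_context:

from lighteval.models.litellm_model import LiteLLMClient


def finalize_prompt(model, messages):
    # Hypothetical helper mirroring the branch added in this commit.
    if isinstance(model, LiteLLMClient):
        # LiteLLM consumes the chat messages directly (see the first file
        # in this commit, where `messages=prompt` is passed to completion).
        return messages
    # Other backends still need a single templated string.
    return model.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )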
