Skip to content

Commit

Permalink
added doc
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier committed Dec 10, 2024
1 parent e3311bd commit a3f535f
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions src/lighteval/models/model_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,19 @@ class GenerationParameters:
truncate_prompt: Optional[bool] = None # vllm, tgi

@classmethod
def from_dict(cls, config_dict):
def from_dict(cls, config_dict: dict):
"""Creates a GenerationParameters object from a config dictionary
Args:
config_dict (dict): Config dictionary. Must obey the following shape:
{"generation_parameters":
{
"early_stopping": value,
...
"truncate_prompt": value
}
}
"""
if "generation_parameters" not in config_dict:
return cls
cls.early_stopping = config_dict["generation_parameters"].get("early_stopping", None)
Expand All @@ -63,7 +75,13 @@ def from_dict(cls, config_dict):
cls.truncate_prompt = config_dict["generation_parameters"].get("truncate_prompt", None)
return cls

def to_vllm_openai_dict(self):
def to_vllm_openai_dict(self) -> dict:
"""Selects relevant generation and sampling parameters for vllm and openai models.
Doc: https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html
Returns:
dict: The parameters used to create a vllm.SamplingParams object, or provided directly as OpenAI sampling parameters in the model config.
"""
# Task specific sampling params to set in model: n, best_of, use_beam_search
# Generation specific params to set in model: logprobs, prompt_logprobs
args = {
Expand All @@ -84,7 +102,13 @@ def to_vllm_openai_dict(self):
}
return {k: v for k, v in args.items() if v is not None}

def to_transformers_dict(self):
def to_transformers_dict(self) -> dict:
"""Selects relevant generation and sampling parameters for transformers models.
Doc: https://huggingface.co/docs/transformers/v4.46.3/en/main_classes/text_generation#transformers.GenerationConfig
Returns:
dict: The parameters to create a transformers.GenerationConfig in the model config.
"""
# Task specific sampling params to set in model: do_sample, num_return_sequences, num_beams
args = {
"max_new_tokens": self.max_new_tokens,
Expand All @@ -104,7 +128,13 @@ def to_transformers_dict(self):
# we still create the object as it uses validation steps
return {k: v for k, v in args.items() if v is not None}

def to_tgi_inferenceendpoint_dict(self):
def to_tgi_inferenceendpoint_dict(self) -> dict:
"""Selects relevant generation and sampling parameters for tgi or inference endpoints models.
Doc: https://huggingface.co/docs/huggingface_hub/v0.26.3/en/package_reference/inference_types#huggingface_hub.TextGenerationInputGenerateParameters
Returns:
dict: The parameters to create a huggingface_hub.TextGenerationInputGenerateParameters in the model config.
"""
# Task specific sampling params to set in model: best_of, do_sample
args = {
"decoder_input_details": True,
Expand Down

0 comments on commit a3f535f

Please sign in to comment.