Commit ff5026b
Apply suggestions from code review
Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
3 people authored Dec 26, 2024
1 parent 90593a9 commit ff5026b
Showing 2 changed files with 6 additions and 39 deletions.
41 changes: 4 additions & 37 deletions src/lighteval/models/model_input.py
@@ -20,7 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 from typing import Optional
 
 
@@ -57,24 +57,7 @@ def from_dict(cls, config_dict: dict):
                 }
             }
         """
-        if "generation" not in config_dict:
-            return GenerationParameters()
-        return GenerationParameters(
-            early_stopping=config_dict["generation"].get("early_stopping", None),
-            repetition_penalty=config_dict["generation"].get("repetition_penalty", None),
-            frequency_penalty=config_dict["generation"].get("frequency_penalty", None),
-            length_penalty=config_dict["generation"].get("length_penalty", None),
-            presence_penalty=config_dict["generation"].get("presence_penalty", None),
-            max_new_tokens=config_dict["generation"].get("max_new_tokens", None),
-            min_new_tokens=config_dict["generation"].get("min_new_tokens", None),
-            seed=config_dict["generation"].get("seed", None),
-            stop_tokens=config_dict["generation"].get("stop_tokens", None),
-            temperature=config_dict["generation"].get("temperature", None),
-            top_k=config_dict["generation"].get("top_k", None),
-            min_p=config_dict["generation"].get("min_p", None),
-            top_p=config_dict["generation"].get("top_p", None),
-            truncate_prompt=config_dict["generation"].get("truncate_prompt", None),
-        )
+        return GenerationParameters(**config_dict.get("generation", {}))
 
     def to_vllm_openai_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for vllm and openai models.
@@ -85,23 +68,7 @@ def to_vllm_openai_dict(self) -> dict:
         """
         # Task specific sampling params to set in model: n, best_of, use_beam_search
         # Generation specific params to set in model: logprobs, prompt_logprobs
-        args = {
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-            "repetition_penalty": self.repetition_penalty,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "top_k": self.top_k,
-            "min_p": self.min_p,
-            "seed": self.seed,
-            "length_penalty": self.length_penalty,
-            "early_stopping": self.early_stopping,
-            "stop": self.stop_tokens,
-            "max_tokens": self.max_new_tokens,
-            "min_tokens": self.min_new_tokens,
-            "truncate_prompt_tokens": self.truncate_prompt,
-        }
-        return {k: v for k, v in args.items() if v is not None}
+        return {k: v for k, v in asdict(self).items() if v is not None}
 
     def to_transformers_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for transformers models.
@@ -117,7 +84,7 @@ def to_transformers_dict(self) -> dict:
         args = {
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
-            "early_stopping": self.early_stopping or False,
+            "early_stopping": self.early_stopping,
             "stop_strings": self.stop_tokens,
             "temperature": self.temperature,
             "top_k": self.top_k,
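The two one-liners above lean on standard dataclass machinery: `**` unpacking into the generated `__init__` replaces the field-by-field construction, and `dataclasses.asdict` keys its output by the dataclass field names. A minimal runnable sketch of the pattern, using a trimmed-down stand-in for `GenerationParameters` (only three of the real fields are shown):

from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class GenerationParameters:
    # Stand-in with a reduced field set; the real class has many more.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None

    @classmethod
    def from_dict(cls, config_dict: dict) -> "GenerationParameters":
        # .get("generation", {}) covers the missing-key case, so the old
        # explicit `if "generation" not in config_dict` branch is redundant.
        return cls(**config_dict.get("generation", {}))

    def to_vllm_openai_dict(self) -> dict:
        # asdict() emits every field keyed by its field name; dropping the
        # Nones keeps only the parameters that were explicitly set.
        return {k: v for k, v in asdict(self).items() if v is not None}


params = GenerationParameters.from_dict({"generation": {"temperature": 0.7}})
print(params.to_vllm_openai_dict())  # {'temperature': 0.7}

Two behavioral notes follow from the diff itself: `asdict` uses the field names as-is, so the deleted explicit mapping (e.g. `stop_tokens` to `"stop"`) no longer applies, and dropping `or False` in `to_transformers_dict` lets an unset `early_stopping` pass through as `None` rather than being coerced to `False`.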
4 changes: 2 additions & 2 deletions src/lighteval/models/transformers/transformers_model.py
@@ -1350,9 +1350,9 @@ def _loglikelihood_single_token(
 
 class BaseModel(TransformersModel):
     def __post_init__(self):
-        super()
+        super().__post_init__()
 
-        logger.warning(
+        warnings.warn(
             "Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!"
         )
 
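Two separate fixes land in this hunk: bare `super()` only builds a proxy object and never invokes the parent's `__post_init__`, and the deprecation notice moves from the logger to Python's warnings machinery, where callers can filter, silence, or escalate it. A minimal sketch of the corrected pattern, with a hypothetical stand-in parent in place of the real `TransformersModel`:

import warnings


class TransformersModel:  # stand-in for the real parent class
    def __post_init__(self):
        print("parent __post_init__ runs")


class BaseModel(TransformersModel):
    def __post_init__(self):
        # super() alone returns a proxy and calls nothing; the parent hook
        # only runs when __post_init__() is invoked on it explicitly.
        super().__post_init__()
        # warnings.warn routes through the warning filters, so consumers can
        # suppress it or turn it into an error, unlike a log record.
        warnings.warn(
            "Careful, the BaseModel name is deprecated and will be removed, "
            "you should use TransformersModel instead!"
        )


BaseModel().__post_init__()  # prints the parent message, then emits the warning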
