From c89c5d4ed3e7ff29ade92689fc11e8949b2e37ef Mon Sep 17 00:00:00 2001
From: guipenedo
Date: Mon, 8 Jul 2024 15:54:44 +0200
Subject: [PATCH] fix issue with `truncated_tokens_count` #2

---
 src/lighteval/models/dummy_model.py  | 7 +++----
 src/lighteval/models/model_output.py | 4 ++--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/lighteval/models/dummy_model.py b/src/lighteval/models/dummy_model.py
index ff7e1e43..4ea7b579 100644
--- a/src/lighteval/models/dummy_model.py
+++ b/src/lighteval/models/dummy_model.py
@@ -66,18 +66,17 @@ def greedy_until(self, requests: list[GreedyUntilRequest], override_bs: Optional
 
     def loglikelihood(self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None) -> list[
         LoglikelihoodReturn]:
-        return [LoglikelihoodReturn((-random.random(), False), truncated_tokens_count=0, padded_tokens_count=0)
+        return [LoglikelihoodReturn((-random.random(), False))
                 for _ in requests]
 
     def loglikelihood_rolling(self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None) -> \
             list[LoglikelihoodReturn]:
-        return [LoglikelihoodReturn((-random.random(), False), truncated_tokens_count=0, padded_tokens_count=0)
+        return [LoglikelihoodReturn((-random.random(), False))
                 for _ in requests]
 
     def loglikelihood_single_token(self, requests: list[LoglikelihoodSingleTokenRequest],
                                    override_bs: Optional[int] = None) -> list[LoglikelihoodSingleTokenReturn]:
         return [
-            LoglikelihoodSingleTokenReturn(result=[-random.random() for _ in req.tokenized_continuation],
-                                           truncated_tokens_count=0, padded_tokens_count=0)
+            LoglikelihoodSingleTokenReturn(result=[-random.random() for _ in req.tokenized_continuation])
             for req in requests
         ]
diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py
index 51027858..ce85c020 100644
--- a/src/lighteval/models/model_output.py
+++ b/src/lighteval/models/model_output.py
@@ -31,8 +31,8 @@ class ModelReturn:
     result: Union[tuple, list, str]
     input_tokens: list[int] = field(default_factory=list)  # model inputs
     generated_tokens: list[int] = field(default_factory=list)  # model generations
-    truncated_tokens_count: Optional[int] = None  # How many tokens truncated
-    padded_tokens_count: Optional[int] = None  # How many tokens of padding
+    truncated_tokens_count: Optional[int] = 0  # How many tokens truncated
+    padded_tokens_count: Optional[int] = 0  # How many tokens of padding
 
     def get_result_for_eval(self):
         raise NotImplementedError()