
Commit

fix test suite 2
clefourrier committed Feb 7, 2024
1 parent d9b262f commit cc1e7dd
Showing 2 changed files with 10 additions and 15 deletions.
src/lighteval/metrics/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -85,7 +85,8 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
         raise ValueError(
             "You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
         )
-    choices_logprob = [sum(results[i].result) for i in range(len(formatted_doc.choices))]
+
+    choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]  # sum(
     gold_ixs = as_list(formatted_doc.gold_index)

     for metric in metrics:
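The metrics-side change swaps `sum(results[i].result)` for `results[i].result[0]`. If each per-choice `result` arrives as the `(float(logit.sum()), bool(maxe))` tuple built in base_model.py below, summing it would fold the exact-match boolean into the score, while indexing `[0]` keeps only the continuation log-probability. A minimal sketch of the intended downstream use, with toy values and a stand-in class rather than lighteval's own types:

from dataclasses import dataclass

@dataclass
class ToyReturn:
    result: tuple  # (summed continuation log-prob, greedy exact-match), as in LoglikelihoodReturn below

# One return per answer choice of a single multichoice doc (toy numbers)
results = [ToyReturn((-5.1, False)), ToyReturn((-0.8, True)), ToyReturn((-2.3, False))]
gold_ixs = [1]

choices_logprob = [results[i].result[0] for i in range(len(results))]  # [-5.1, -0.8, -2.3]
best_choice = max(range(len(choices_logprob)), key=choices_logprob.__getitem__)
print(best_choice in gold_ixs)  # True: the gold choice has the highest log-likelihood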
src/lighteval/models/base_model.py (22 changes: 8 additions & 14 deletions)
@@ -580,32 +580,25 @@ def _loglikelihood_tokens(
                 max_equals = []
                 batch_cont_tokens = []
                 for cur_request, cur_logits, inplen in zip(batch, logits, prepared_batch.input_lengths):
-                    cont_toks = cur_request.tokenized_continuation
+                    cont_toks = torch.tensor(cur_request.tokenized_continuation, dtype=torch.long, device=self.device)
+                    contlen = cont_toks.shape[0]
                     # We only look at the continuation tokens
-                    contlen = len(cont_toks)
                     if contlen > inplen:
                         # Continuation is longer than the input size, we are in rolling mode (only continuation)
                         cur_logits = cur_logits.unsqueeze(0).to(self.device)  # [1, seq, vocab]
-                        cont_toks = (
-                            torch.tensor(cont_toks, dtype=torch.long, device=self.device)[:inplen]
-                            .unsqueeze(0)
-                            .to(self.device)
-                        )  # [1, seq]
+                        cont_toks = cont_toks[:inplen].unsqueeze(0).to(self.device)  # [1, seq]
                     else:
                         cur_logits = (
                             cur_logits[inplen - contlen : inplen].unsqueeze(0).to(self.device)
-                        )  # [1, seq, vocab]
-                        cont_toks = (
-                            torch.tensor(cont_toks, dtype=torch.long, device=self.device).unsqueeze(0).to(self.device)
-                        )  # [1, seq]
+                        )  # [1, seq, voc]
+                        cont_toks = cont_toks.unsqueeze(0).to(self.device)  # [1, seq]

                     # Check if per-token argmax is exactly equal to continuation
                     greedy_tokens = cur_logits.argmax(dim=-1).to(self.device)
                     # Sometimes the continuation is longer than allowed by the model, we only look at the first tokens
                     max_equal = (greedy_tokens == cont_toks).all().squeeze(0).to(self.device)

                     # Obtain log-probs at the corresponding continuation token indices
-                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                     cur_logits = torch.gather(cur_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                     # Answer: (log prob, is-exact-match)
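The hunk above only changes how cont_toks is built (one tensor up front instead of repeated torch.tensor calls); the scoring itself is unchanged. For reference, a self-contained sketch of that scoring pattern on made-up shapes, not lighteval's batching: slice the logits to the continuation window, check the greedy exact match, and gather the log-prob of each gold continuation token.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, inplen = 10, 6
cont_toks = torch.tensor([3, 7, 1], dtype=torch.long)       # gold continuation token ids
logits = F.log_softmax(torch.randn(inplen, vocab), dim=-1)   # [seq, vocab] for one sample

contlen = cont_toks.shape[0]
cur_logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, contlen, vocab]
cont_toks = cont_toks.unsqueeze(0)                           # [1, contlen]

# Exact match if the argmax at every continuation position equals the gold token
max_equal = (cur_logits.argmax(dim=-1) == cont_toks).all()

# Log-prob of each gold continuation token, summed into a single score
token_logprobs = torch.gather(cur_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, contlen]
print(float(token_logprobs.sum()), bool(max_equal))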
@@ -640,6 +633,7 @@ def _loglikelihood_tokens(
                     zip(logits, batch_cont_tokens, max_equal, batched_inputs, batch_truncated, batch_padded)
                 ):
                     answer = LoglikelihoodReturn(
+                        # todo: we might want to store the logits unsummed
                         result=(float(logit.sum()), bool(maxe)) if return_bool_score else float(logit.sum()),
                         input_tokens=batched_input[: len_inputs[ix]].cpu().tolist(),
                         generated_tokens=cont_tokens[: len_tokens[ix]].cpu().tolist(),
@@ -695,10 +689,10 @@ def prepare_batch(

             padded.append(padding_length - sequence_len)
             # Right padding - it likely would be better to do left padding
-            tokens = F.pad(tokens, (0, padding_length - sequence_len), value=0)
+            tokens = F.pad(tokens, (0, padding_length - sequence_len), value=self.tokenizer.pad_token_id)

             # We create the attention mask to ignore padding
-            mask = tokens == 0
+            mask = tokens == self.tokenizer.pad_token_id
             attention_masks.append(mask)

             input_tokens.append(tokens.unsqueeze(0))  # [1, padding_length]
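The prepare_batch change pads with the tokenizer's pad token instead of a hard-coded 0 and derives the padding mask from that same id; with a vocabulary where 0 is a real token id, tokens == 0 would also flag genuine content. A rough illustration with a hypothetical pad id rather than a real tokenizer:

import torch
import torch.nn.functional as F

pad_token_id = 2                   # hypothetical; the real value comes from self.tokenizer
tokens = torch.tensor([5, 0, 9])   # note: 0 here is a legitimate token id
padding_length = 6

# Right padding up to padding_length, as in prepare_batch
tokens = F.pad(tokens, (0, padding_length - tokens.shape[0]), value=pad_token_id)

mask_old = tokens == 0             # old behaviour: also flags the real token id 0
mask_new = tokens == pad_token_id  # new behaviour: flags only the padding
print(tokens.tolist())             # [5, 0, 9, 2, 2, 2]
print(mask_old.tolist())           # [False, True, False, False, False, False]
print(mask_new.tolist())           # [False, False, False, True, True, True]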
