Skip to content

Commit

Permalink
Update local_infer.py
Browse files Browse the repository at this point in the history
Avoid an exception on very long text by enabling truncation in the tokenizer calls.
  • Loading branch information
baoguangsheng committed Jul 16, 2024
1 parent 6041d57 commit 43c289b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scripts/local_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def run(args):
if len(text) == 0:
break
# evaluate text
- tokenized = scoring_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
+ tokenized = scoring_tokenizer(text, truncation=True, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
labels = tokenized.input_ids[:, 1:]
with torch.no_grad():
logits_score = scoring_model(**tokenized).logits[:, :-1]
if args.reference_model_name == args.scoring_model_name:
logits_ref = logits_score
else:
- tokenized = reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
+ tokenized = reference_tokenizer(text, truncation=True, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
logits_ref = reference_model(**tokenized).logits[:, :-1]
crit = criterion_fn(logits_ref, logits_score, labels)
Expand Down

0 comments on commit 43c289b

Please sign in to comment.