Skip to content

Commit

Permalink
model_run_utils: return proper finish reason in generate_text_func
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Trifirò <[email protected]>
  • Loading branch information
dtrifiro committed Oct 2, 2023
1 parent eab5fcf commit bb8ca31
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion caikit_nlp/toolkit/text_generation/model_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ def generate_text_func(
for g in generate_ids
]

if generate_ids[0][-1].item() == eos_token:
if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or (
generate_ids[0, -1] == tokenizer.eos_token_id
):
finish_reason = "EOS_TOKEN"
elif generate_ids.size(1) - 1 == max_new_tokens:
finish_reason = "MAX_TOKENS"
Expand Down

0 comments on commit bb8ca31

Please sign in to comment.