From bb8ca31f10d0e15db72d22bbfbe67de49402433c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 28 Sep 2023 17:54:24 +0200 Subject: [PATCH] model_run_utils: return proper finish reason in generate_text_func MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- caikit_nlp/toolkit/text_generation/model_run_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index 641c59f7..da75528e 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -233,7 +233,9 @@ def generate_text_func( for g in generate_ids ] - if generate_ids[0][-1].item() == eos_token: + if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or ( + generate_ids[0, -1] == tokenizer.eos_token_id + ): finish_reason = "EOS_TOKEN" elif generate_ids.size(1) - 1 == max_new_tokens: finish_reason = "MAX_TOKENS"