From 4d15fda90c5ec5ad8c36e67291b4e6fe97d02e03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Fri, 29 Sep 2023 17:11:32 +0200 Subject: [PATCH] model_run_utils: add STOP_SEQUENCE finish reason MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- caikit_nlp/toolkit/text_generation/model_run_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index da75528e..df1671cd 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -239,6 +239,13 @@ def generate_text_func( finish_reason = "EOS_TOKEN" elif generate_ids.size(1) - 1 == max_new_tokens: finish_reason = "MAX_TOKENS" + elif ("stopping_criteria" in gen_optional_params) and ( + gen_optional_params["stopping_criteria"]( + generate_ids, + None, # scores, unused by SequenceStoppingCriteria + ) + ): + finish_reason = "STOP_SEQUENCE" else: finish_reason = "OTHER" return GeneratedTextResult(