diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index da75528e..df1671cd 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -239,6 +239,13 @@ def generate_text_func( finish_reason = "EOS_TOKEN" elif generate_ids.size(1) - 1 == max_new_tokens: finish_reason = "MAX_TOKENS" + elif ("stopping_criteria" in gen_optional_params) and ( + gen_optional_params["stopping_criteria"]( + generate_ids, + None, # scores, unused by SequenceStoppingCriteria + ) + ): + finish_reason = "STOP_SEQUENCE" else: finish_reason = "OTHER" return GeneratedTextResult(