From 7fff5c0c211d59fccf1bd72c7f883690eabe7e73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 28 Sep 2023 15:15:16 +0200 Subject: [PATCH 1/4] model_run_utils: add/fix type hints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- caikit_nlp/toolkit/text_generation/model_run_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index b4a7dc77..9b56e179 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -19,7 +19,7 @@ # Third Party from peft.peft_model import PeftModel -from transformers import StoppingCriteria, TextStreamer +from transformers import AutoModel, AutoTokenizer, StoppingCriteria, TextStreamer import numpy as np import torch @@ -131,10 +131,10 @@ def __iter__(self): def generate_text_func( - model, - tokenizer, + model: "Union[PeftModel, AutoModel]", + tokenizer: "AutoTokenizer", producer_id: ProducerId, - eos_token: str, + eos_token: Optional[str], text: str, max_new_tokens: Optional[int] = 20, min_new_tokens: Optional[int] = 0, From ac28f6a9416e5a66b2e94311735b6068248ba40e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Thu, 28 Sep 2023 17:54:24 +0200 Subject: [PATCH 2/4] model_run_utils: fix logic for EOS_TOKEN finish reason MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- caikit_nlp/toolkit/text_generation/model_run_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index 9b56e179..7bd72347 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -234,7 +234,9 @@ def generate_text_func( for g in generate_ids ] - if generate_ids[0][-1].item() == eos_token: + if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or ( + generate_ids[0, -1] == tokenizer.eos_token_id + ): finish_reason = "EOS_TOKEN" elif generate_ids.size(1) - 1 == max_new_tokens: finish_reason = "MAX_TOKENS" From 185ff2244a82818a8a1eaa6ed76d457bb7098249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Fri, 27 Oct 2023 12:20:23 +0200 Subject: [PATCH 3/4] model_run_utils: add STOP_SEQUENCE finish reason MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniele Trifirò --- caikit_nlp/toolkit/text_generation/model_run_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index 7bd72347..cde879d5 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -238,6 +238,13 @@ def generate_text_func( generate_ids[0, -1] == tokenizer.eos_token_id ): finish_reason = "EOS_TOKEN" + elif ("stopping_criteria" in gen_optional_params) and ( + gen_optional_params["stopping_criteria"]( + generate_ids, + None, # scores, unused by SequenceStoppingCriteria + ) + ): + finish_reason = "STOP_SEQUENCE" elif generate_ids.size(1) - 1 == max_new_tokens: finish_reason = "MAX_TOKENS" else: From 1b4581747d29c40f9ebe836cf0a25df67473c92f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Sat, 7 Oct 2023 12:39:47 +0200 Subject: [PATCH 4/4] model_run_utils: use MAX_TOKENS as default finish_reason in generate_text_func MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit "OTHER" is an invalid value for caikit.interfaces.nlp.data_model.text_generation.FinishReason, resulting in failed serialization of responses when querying the text generation endpoint. For `generate_text_func`, it is reasonable to assume that if the finish reason is not `EOS_TOKEN` or `STOP_SEQUENCE`, it must be `MAX_TOKENS`. Signed-off-by: Daniele Trifirò --- caikit_nlp/toolkit/text_generation/model_run_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py index cde879d5..eb45064f 100644 --- a/caikit_nlp/toolkit/text_generation/model_run_utils.py +++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py @@ -245,10 +245,9 @@ def generate_text_func( ) ): finish_reason = "STOP_SEQUENCE" - elif generate_ids.size(1) - 1 == max_new_tokens: - finish_reason = "MAX_TOKENS" else: - finish_reason = "OTHER" + finish_reason = "MAX_TOKENS" + return GeneratedTextResult( generated_tokens=token_count, generated_text=preds[0],