diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
index 61b78efb..ae7ecdf5 100644
--- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py
+++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -213,10 +213,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -280,10 +280,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing against the model running in TGIS
diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py
index 2c458d27..d4626ec0 100644
--- a/caikit_nlp/modules/text_generation/text_generation_tgis.py
+++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -235,10 +235,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -296,10 +296,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing for text generation module.
diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py
index 0bb7d70e..0b870ba6 100644
--- a/caikit_nlp/toolkit/text_generation/tgis_utils.py
+++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -144,7 +144,10 @@ def validate_inf_params(
     error.type_check("", bool, token_logprobs=token_logprobs)
     error.type_check("", bool, token_ranks=token_ranks)
     error.type_check(
-        "", bool, include_stop_sequence=include_stop_sequence
+        "",
+        bool,
+        allow_none=True,
+        include_stop_sequence=include_stop_sequence,
     )
     error.type_check("", str, allow_none=True, eos_token=eos_token)
     error.type_check(
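
For context (not part of the diff): a minimal sketch of the None-tolerant type check the last hunk switches to. The `type_check` below is a hypothetical stand-in for caikit's `error.type_check` helper, written out here only to show why `allow_none=True` becomes necessary once `include_stop_sequence` defaults to `None` instead of `True`.

# Illustration only -- not part of the diff above. A self-contained
# stand-in for caikit's error.type_check (the real helper lives in
# caikit.core.exceptions); signatures and behavior here are assumptions
# for demonstration, not the library's actual implementation.
from typing import Optional


def type_check(log_code: str, *types: type, allow_none: bool = False, **variables):
    """Raise TypeError unless every named variable is an instance of `types`.

    If allow_none is True, a value of None is also accepted.
    """
    for name, value in variables.items():
        if value is None and allow_none:
            continue
        if not isinstance(value, types):
            raise TypeError(
                f"{log_code} type check failed: {name}={value!r} is not {types}"
            )


# With the new default of include_stop_sequence=None, the pre-change
# validator (no allow_none) would reject the default value itself:
include_stop_sequence: Optional[bool] = None
try:
    type_check("", bool, include_stop_sequence=include_stop_sequence)
except TypeError as err:
    print(err)  # raises, as the old check would on the new default

# The updated call accepts None while still rejecting non-bool values:
type_check("", bool, allow_none=True, include_stop_sequence=include_stop_sequence)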